diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..c6294d30 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,31 @@ +[run] +source = + splitio/ + +omit = + tests/* + */__init__.py + +branch = True + +relative_files = True + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + +precision = 2 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..9e319810 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @splitio/sdk diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 1b64b3e5..95efd4c7 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -4,4 +4,4 @@ ## How do we test the changes introduced in this PR? -## Extra Notes \ No newline at end of file +## Extra Notes diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10946afe..df28cd54 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,14 +3,20 @@ on: push: branches: - master + - development pull_request: branches: - master - development +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + jobs: test: - runs-on: ubuntu-latest + name: Test + runs-on: ubuntu-22.04 services: redis: image: redis @@ -18,52 +24,51 @@ jobs: - 6379:6379 steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v5 with: fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v2 + - name: Setup Python + uses: actions/setup-python@v6 with: - python-version: '3.6' + python-version: '3.7.16' - name: Install dependencies run: | - pip install -U setuptools pip + sudo apt update + sudo apt-get install -y libkrb5-dev + pip install -U setuptools pip wheel pip install -e .[cpphash,redis,uwsgi] - name: Run tests run: python setup.py test + - name: Set VERSION env + run: echo "VERSION=$(cat splitio/version.py | grep "__version__" | awk -F\' '{print $2}')" >> $GITHUB_ENV + - name: SonarQube Scan (Push) if: github.event_name == 'push' - uses: SonarSource/sonarcloud-github-action@v1.5 + uses: SonarSource/sonarqube-scan-action@v6 env: SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: projectBaseDir: . args: > -Dsonar.host.url=${{ secrets.SONARQUBE_HOST }} - -Dsonar.projectName=${{ github.event.repository.name }} - -Dsonar.projectKey=${{ github.event.repository.name }} - -Dsonar.python.coverage.reportPaths=coverage.xml - -Dsonar.links.ci="https://github.com/splitio/${{ github.event.repository.name }}/actions" - -Dsonar.links.scm="https://github.com/splitio/${{ github.event.repository.name }}" + -Dsonar.projectVersion=${{ env.VERSION }} - name: SonarQube Scan (Pull Request) if: github.event_name == 'pull_request' - uses: SonarSource/sonarcloud-github-action@v1.5 + uses: SonarSource/sonarqube-scan-action@v6 env: SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: projectBaseDir: . args: > -Dsonar.host.url=${{ secrets.SONARQUBE_HOST }} - -Dsonar.projectName=${{ github.event.repository.name }} - -Dsonar.projectKey=${{ github.event.repository.name }} - -Dsonar.python.coverage.reportPaths=coverage.xml - -Dsonar.links.ci="https://github.com/splitio/${{ github.event.repository.name }}/actions" - -Dsonar.links.scm="https://github.com/splitio/${{ github.event.repository.name }}" + -Dsonar.projectVersion=${{ env.VERSION }} -Dsonar.pullrequest.key=${{ github.event.pull_request.number }} -Dsonar.pullrequest.branch=${{ github.event.pull_request.head.ref }} -Dsonar.pullrequest.base=${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/update-license-year.yml b/.github/workflows/update-license-year.yml index 989caf52..884edbe9 100644 --- a/.github/workflows/update-license-year.yml +++ b/.github/workflows/update-license-year.yml @@ -13,18 +13,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: fetch-depth: 0 - + - name: Set Current year run: "echo CURRENT=$(date +%Y) >> $GITHUB_ENV" - + - name: Set Previous Year run: "echo PREVIOUS=$(($CURRENT-1)) >> $GITHUB_ENV" - name: Update LICENSE - uses: jacobtomlinson/gha-find-replace@v2 + uses: jacobtomlinson/gha-find-replace@v3 with: find: ${{ env.PREVIOUS }} replace: ${{ env.CURRENT }} @@ -38,7 +38,7 @@ jobs: git commit -m "Updated License Year" -a - name: Create Pull Request - uses: peter-evans/create-pull-request@v3 + uses: peter-evans/create-pull-request@v5 with: token: ${{ secrets.GITHUB_TOKEN }} title: Update License Year diff --git a/.gitignore b/.gitignore index 31959c04..d2f290a3 100644 --- a/.gitignore +++ b/.gitignore @@ -72,4 +72,7 @@ target/ # vim backup files *.swp -.DS_Store \ No newline at end of file +.DS_Store + +# Sonarqube +.scannerwork diff --git a/CHANGES.txt b/CHANGES.txt index 683a0bdf..0845c52e 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,95 @@ +10.6.0 (Jan 28, 2026) +- Fixed non-blocking error when fetching feature flags from redis. +- Added the ability to listen to different events triggered by the SDK. Read more in our docs. + - SDK_UPDATE notify when a flag or user segment has changed + - SDK_READY notify when the SDK is ready to evaluate + +10.5.1 (Oct 15, 2025) +- Added using String only parameter for treatments in FallbackTreatmentConfiguration class. + +10.5.0 (Sep 15, 2025) +- Changed the log level from error to debug when renewing the token for Streaming service in asyncio mode. +- Added new configuration for Fallback Treatments, which allows setting a treatment value and optional config to be returned in place of "control", either globally or by flag. Read more in our docs. +- Deprecated config parameter `redisErrors` as it is removed in redis lib since 6.0.0 version (https://github.com/redis/redis-py/releases/tag/v6.0.0). + +10.4.0 (Aug 4, 2025) +- Added a new optional argument to the client `getTreatment` methods to allow passing additional evaluation options, such as a map of properties to append to the generated impressions sent to Split backend. Read more in our docs. + +10.3.0 (Jun 17, 2025) +- Added support for rule-based segments. These segments determine membership at runtime by evaluating their configured rules against the user attributes provided to the SDK. +- Added support for feature flag prerequisites. This allows customers to define dependency conditions between flags, which are evaluated before any allowlists or targeting rules. + +10.2.0 (Jan 17, 2025) +- Added support for the new impressions tracking toggle available on feature flags, both respecting the setting and including the new field being returned on SplitView type objects. Read more in our docs. + +10.1.0 (Aug 7, 2024) +- Added support for Kerberos authentication in Spnego and Proxy Kerberos server instances. + +10.0.1 (Jun 28, 2024) +- Fixed failure to load lib issue in SDK startup for Python versions higher than or equal to 3.10 + +10.0.0 (Jun 27, 2024) +- Added support for asyncio library +- BREAKING CHANGE: Minimum supported Python version is 3.7.16 + +9.7.0 (May 15, 2024) +- Added support for targeting rules based on semantic versions (https://semver.org/). +- Added the logic to handle correctly when the SDK receives an unsupported Matcher type. + +9.6.2 (Apr 5, 2024) +- Fixed an issue when pushing unique keys tracker data to redis if no keys exist, i.e. get_treatment flavors are not called. + +9.6.1 (Feb 15, 2024) +- Added redisUsername configuration parameter for Redis connection to set the username for accessing redis when not using the default `root` username + +9.6.0 (Nov 3, 2023) +- Added support for Flag Sets on the SDK, which enables grouping feature flags and interacting with the group rather than individually (more details in our documentation): + - Added new variations of the get treatment methods to support evaluating flags in given flag set/s. + - get_treatments_by_flag_set and get_treatments_by_flag_sets + - get_treatments_with_config_by_flag_set and get_treatments_with_config_by_flag_sets +- Added a new optional Split Filter configuration option. This allows the SDK and Split services to only synchronize the flags in the specified flag sets, avoiding unused or unwanted flags from being synced on the SDK instance, bringing all the benefits from a reduced payload. + - Note: Only applicable when the SDK is in charge of the rollout data synchronization. When not applicable, the SDK will log a warning on init. +- Updated the following SDK manager methods to expose flag sets on flag views. +- Removed raising an exception when Telemetry post config data fails, SDK will only log the error. + +9.5.1 (Sep 5, 2023) +- Exclude tests from when building the package +- Fixed exception when fetching telemetry stats if no SSE Feature flags update events are stored + +9.5.0 (Jul 18, 2023) +- Improved streaming architecture implementation to apply feature flag updates from the notification received which is now enhanced, improving efficiency and reliability of the whole update system. + +9.4.2 (May 15, 2023) +- Updated terminology on the SDKs codebase to be more aligned with current standard without causing a breaking change. The core change is the term split for feature flag on things like logs and code documentation comments. +- Added detailed debug logging for redis adapter. +- Fixed setting defaultTreatment to 'control' if it is missing in localhost JSON file. + +9.4.1 (Apr 18, 2023) +- Fixed storing incorrect Telemetry method latency data + +9.4.0 (Mar 1, 2023) +- Added support to use JSON files in localhost mode. +- Updated default periodic telemetry post time to one hour. +- Fixed unhandeled exception in push.manager.py class when SDK is connected to split proxy + +9.3.0 (Jan 30, 2023) +- Updated SDK telemetry storage, metrics and updater to be more effective and send less often. +- Removed deprecated threading.Thread.setDaemon() method. + +9.2.2 (Dec 13, 2022) +- Fixed RedisSenderAdapter instantiation to store mtk keys. + +9.2.1 (Dec 2, 2022) +- Changed redis record type for impressions counts from list using rpush to hashed key using hincrby. +- Apply Timeout Exception when incorrect SDK API Key is used. +- Changed potential initial fetching segment Warning to Debug in logging. + +9.2.0 (Oct 14, 2022) +- Added a new impressions mode for the SDK called NONE , to be used in factory when there is no desire to capture impressions on an SDK factory to feed Split's analytics engine. Running NONE mode, the SDK will only capture unique keys evaluated for a particular feature flag instead of full blown impressions + +9.1.3 (July 25, 2022) +- Fixed synching missed segment(s) after receiving split update + 9.1.2 (April 6, 2022) - Updated pyyaml dependency for vulnerability CVE-2020-14343. @@ -77,7 +169,7 @@ 7.0.1 (Mar 8, 2019) - Updated Splits refreshing rate. - Replaced exception log level to error level. - - Improved validation for apikey. + - Improved validation for sdkkey. 7.0.0 (Feb 21, 2019) - BREAKING CHANGE: Stored Impressions in Queue. diff --git a/CONTRIBUTORS-GUIDE.md b/CONTRIBUTORS-GUIDE.md index 11483a32..befff911 100644 --- a/CONTRIBUTORS-GUIDE.md +++ b/CONTRIBUTORS-GUIDE.md @@ -28,4 +28,4 @@ To run test you need to execute the following commands: # Contact -If you have any other questions or need to contact us directly in a private manner send us a note at sdks@split.io. \ No newline at end of file +If you have any other questions or need to contact us directly in a private manner send us a note at sdks@split.io. diff --git a/LICENSE.txt b/LICENSE.txt index 051b5fd9..0f9e8a59 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,13 +1,169 @@ -Copyright © 2022 Split Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + END OF TERMS AND CONDITIONS + APPENDIX: How to apply the Apache License to your work. + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + Copyright [yyyy] [name of copyright owner] + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 00000000..7d7d845e --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,5 @@ +Harness Feature Management JavaScript SDK Copyright 2024-2026 Harness Inc. + +This product includes software developed at Harness Inc. (https://harness.io/). + +This product includes software originally developed by Split Software, Inc. (https://www.split.io/). Copyright 2015-2024 Split Software, Inc. diff --git a/README.md b/README.md index 4a511fd1..5dae06bf 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# Split Python SDK -[![Build Status](https://api.travis-ci.com/splitio/python-client.svg?branch=master)](https://api.travis-ci.com/splitio/python-client) +# Split Python SDK +![Build Status](https://github.com/splitio/python-client/actions/workflows/ci.yml/badge.svg?branch=master) ## Overview This SDK is designed to work with Split, the platform for controlled rollouts, which serves features to your users via a Split feature flag to manage your complete customer experience. @@ -7,7 +7,7 @@ This SDK is designed to work with Split, the platform for controlled rollouts, w [![Twitter Follow](https://img.shields.io/twitter/follow/splitsoftware.svg?style=social&label=Follow&maxAge=1529000)](https://twitter.com/intent/follow?screen_name=splitsoftware) ## Compatibility -This SDK is compatible with **Python 3 and higher**. +This SDK is compatible with **Python 3.7 and higher**. ## Getting started Below is a simple example that describes the instantiation and most basic usage of our SDK: @@ -22,8 +22,8 @@ factory = get_factory('YOUR_SDK_TYPE_API_KEY', config=config) try: factory.block_until_ready(5) # wait up to 5 seconds split = factory.client() - treatment = split.get_treatment('CUSTOMER_ID', 'SPLIT_NAME') - if treatment == "on": + treatment = split.get_treatment('CUSTOMER_ID', 'FEATURE_FLAG_NAME') + if treatment == "on": # insert code here to show on treatment elif treatment == "off": # insert code here to show off treatment @@ -54,16 +54,21 @@ To learn more about Split, contact hello@split.io, or get started with feature f Split has built and maintains SDKs for: +* .NET [Github](https://github.com/splitio/dotnet-client) [Docs](https://help.split.io/hc/en-us/articles/360020240172--NET-SDK) +* Android [Github](https://github.com/splitio/android-client) [Docs](https://help.split.io/hc/en-us/articles/360020343291-Android-SDK) +* Angular [Github](https://github.com/splitio/angular-sdk-plugin) [Docs](https://help.split.io/hc/en-us/articles/6495326064397-Angular-utilities) +* GO [Github](https://github.com/splitio/go-client) [Docs](https://help.split.io/hc/en-us/articles/360020093652-Go-SDK) +* iOS [Github](https://github.com/splitio/ios-client) [Docs](https://help.split.io/hc/en-us/articles/360020401491-iOS-SDK) * Java [Github](https://github.com/splitio/java-client) [Docs](https://help.split.io/hc/en-us/articles/360020405151-Java-SDK) -* Javascript [Github](https://github.com/splitio/javascript-client) [Docs](https://help.split.io/hc/en-us/articles/360020448791-JavaScript-SDK) +* JavaScript [Github](https://github.com/splitio/javascript-client) [Docs](https://help.split.io/hc/en-us/articles/360020448791-JavaScript-SDK) +* JavaScript for Browser [Github](https://github.com/splitio/javascript-browser-client) [Docs](https://help.split.io/hc/en-us/articles/360058730852-Browser-SDK) * Node [Github](https://github.com/splitio/javascript-client) [Docs](https://help.split.io/hc/en-us/articles/360020564931-Node-js-SDK) -* .NET [Github](https://github.com/splitio/dotnet-client) [Docs](https://help.split.io/hc/en-us/articles/360020240172--NET-SDK) -* Ruby [Github](https://github.com/splitio/ruby-client) [Docs](https://help.split.io/hc/en-us/articles/360020673251-Ruby-SDK) * PHP [Github](https://github.com/splitio/php-client) [Docs](https://help.split.io/hc/en-us/articles/360020350372-PHP-SDK) * Python [Github](https://github.com/splitio/python-client) [Docs](https://help.split.io/hc/en-us/articles/360020359652-Python-SDK) -* GO [Github](https://github.com/splitio/go-client) [Docs](https://help.split.io/hc/en-us/articles/360020093652-Go-SDK) -* Android [Github](https://github.com/splitio/android-client) [Docs](https://help.split.io/hc/en-us/articles/360020343291-Android-SDK) -* iOS [Github](https://github.com/splitio/ios-client) [Docs](https://help.split.io/hc/en-us/articles/360020401491-iOS-SDK) +* React [Github](https://github.com/splitio/react-client) [Docs](https://help.split.io/hc/en-us/articles/360038825091-React-SDK) +* React Native [Github](https://github.com/splitio/react-native-client) [Docs](https://help.split.io/hc/en-us/articles/4406066357901-React-Native-SDK) +* Redux [Github](https://github.com/splitio/redux-client) [Docs](https://help.split.io/hc/en-us/articles/360038851551-Redux-SDK) +* Ruby [Github](https://github.com/splitio/ruby-client) [Docs](https://help.split.io/hc/en-us/articles/360020673251-Ruby-SDK) For a comprehensive list of open source projects visit our [Github page](https://github.com/splitio?utf8=%E2%9C%93&query=%20only%3Apublic%20). diff --git a/doc/source/flask_support.rst b/doc/source/flask_support.rst index 7e1abf74..9ed4b8b8 100644 --- a/doc/source/flask_support.rst +++ b/doc/source/flask_support.rst @@ -37,4 +37,4 @@ This example assumes that the Split.io configuration is save in a file called `` When using the Redis client the update scripts need to be run periodically, otherwise there won't be any data available to the client. -As mentioned before, if the API key is set to ``'localhost'`` a localhost environment client is generated and no connections to Split.io are made as everything is read from ``.split`` file (you can read about this feature in the Localhost Environment section of the :doc:`/introduction`.) \ No newline at end of file +As mentioned before, if the API key is set to ``'localhost'`` a localhost environment client is generated and no connections to Split.io are made as everything is read from ``.split`` file (you can read about this feature in the Localhost Environment section of the :doc:`/introduction`.) diff --git a/doc/source/index.rst b/doc/source/index.rst index 8a61310b..249d74eb 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -20,4 +20,3 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` * :ref:`search` - diff --git a/doc/source/introduction.rst b/doc/source/introduction.rst index bcad2158..a6df7a71 100644 --- a/doc/source/introduction.rst +++ b/doc/source/introduction.rst @@ -166,7 +166,7 @@ The client depends on the information for features and segments being updated ex The scripts are configured through a JSON settings file, like the following: :: { - "apiKey": "some-api-key", + "sdkKey": "some-sdk-key", "sdkApiBaseUrl": "https://sdk.split.io/api", "eventsApiBaseUrl": "https://events.split.io/api", "redisFactory": 'some.redis.factory', @@ -180,7 +180,7 @@ These are the possible configuration parameters: +------------------------+------+--------------------------------------------------------+-------------------------------+ | Key | Type | Description | Default | +========================+======+========================================================+===============================+ -| apiKey | str | A valid Split.io API key. | None | +| sdkKey | str | A valid Split.io SDK key. | None | +------------------------+------+--------------------------------------------------------+-------------------------------+ | sdkApiBaseUrl | str | The SDK API url base | "https://sdk.split.io/api" | +------------------------+------+--------------------------------------------------------+-------------------------------+ @@ -238,7 +238,7 @@ On the other hand, there is available a python script named ``splitio.bin.synchr The configuration file is a JSON file with the following fields: { - "apiKey": "YOUR_API_KEY", + "sdkKey": "YOUR_SDK_KEY", "redisHost": "REDIS_DNS_OR_IP", "redisPort": 6379, "redisDb": 0 @@ -274,7 +274,7 @@ In order to support Redis' Sentinel host discovery, you need to provide a custom Afterwards you tell the client to use this factory using the config file: :: { - "apiKey": "some-api-key", + "sdkKey": "some-sdk-key", "sdkApiBaseUrl": "https://sdk.split.io/api", "eventsApiBaseUrl": "https://events.split.io/api", "redisFactory": 'redis_config.my_redis_factory' diff --git a/setup.cfg b/setup.cfg index 164be372..1fa09f42 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,10 @@ universal = 1 [metadata] -description-file = README.md +name = splitio_client +description = This SDK is designed to work with Split, the platform for controlled rollouts, which serves features to your users via a Split feature flag to manage your complete customer experience. +long_description = file: README.md +long_description_content_type = text/markdown [flake8] max-line-length=100 @@ -12,7 +15,6 @@ exclude=tests/* test=pytest [tool:pytest] -ignore_glob=./splitio/_OLD/* addopts = --verbose --cov=splitio --cov-report xml python_classes=*Tests diff --git a/setup.py b/setup.py index 975daf13..e2b4c74a 100644 --- a/setup.py +++ b/setup.py @@ -7,18 +7,26 @@ TESTS_REQUIRES = [ 'flake8', 'pytest==7.0.1', - 'pytest-mock>=3.5.1', - 'coverage', - 'pytest-cov', - 'importlib-metadata==4.2', + 'pytest-mock==3.11.1', + 'coverage==7.0.0', + 'pytest-cov==4.1.0', + 'importlib-metadata==6.7', 'tomli==1.2.3', + 'iniconfig==1.1.1', + 'attrs==22.1.0', + 'pytest-asyncio==0.21.0', + 'aiohttp>=3.8.4', + 'aiofiles>=23.1.0', + 'requests-kerberos>=0.15.0', + 'urllib3==2.0.7' ] INSTALL_REQUIRES = [ - 'requests>=2.9.1', - 'pyyaml>=5.4', + 'requests', + 'pyyaml', 'docopt>=0.6.2', 'enum34;python_version<"3.4"', + 'bloom-filter2>=2.0.0' ] with open(path.join(path.abspath(path.dirname(__file__)), 'splitio', 'version.py')) as f: @@ -37,11 +45,13 @@ tests_require=TESTS_REQUIRES, extras_require={ 'test': TESTS_REQUIRES, - 'redis': ['redis>=2.10.5'], + 'redis': ['redis>=2.10.5,<7.0.0'], 'uwsgi': ['uwsgi>=2.0.0'], 'cpphash': ['mmh3cffi==0.2.1'], + 'asyncio': ['aiohttp>=3.8.4', 'aiofiles>=23.1.0'], + 'kerberos': ['requests-kerberos>=0.15.0'] }, - setup_requires=['pytest-runner'], + setup_requires=['pytest-runner', 'pluggy==1.0.0;python_version<"3.8"'], classifiers=[ 'Environment :: Console', 'Intended Audience :: Developers', @@ -50,5 +60,5 @@ 'Programming Language :: Python :: 3', 'Topic :: Software Development :: Libraries' ], - packages=find_packages() + packages=find_packages(exclude=('tests', 'tests.*')) ) diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 00000000..009f4fd7 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,10 @@ +sonar.projectName=python-client +sonar.projectKey=python-client +sonar.python.version=3.6 +sonar.sources=splitio +sonar.tests=tests +sonar.text.excluded.file.suffixes=.csv +sonar.python.coverage.reportPaths=coverage.xml +sonar.coverage.exclusions=**/__init__.py +sonar.links.ci=https://github.com/splitio/python-client +sonar.links.scm=https://github.com/splitio/python-client/actions diff --git a/splitio/__init__.py b/splitio/__init__.py index aced4602..e9c9302b 100644 --- a/splitio/__init__.py +++ b/splitio/__init__.py @@ -1,3 +1,3 @@ -from splitio.client.factory import get_factory +from splitio.client.factory import get_factory, get_factory_async from splitio.client.key import Key from splitio.version import __version__ diff --git a/splitio/api/__init__.py b/splitio/api/__init__.py index 33f1e588..be820f14 100644 --- a/splitio/api/__init__.py +++ b/splitio/api/__init__.py @@ -13,3 +13,34 @@ def __init__(self, custom_message, status_code=None): def status_code(self): """Return HTTP status code.""" return self._status_code + +class APIUriException(APIException): + """Exception to raise when an API call fails due to 414 http error.""" + + def __init__(self, custom_message, status_code=None): + """Constructor.""" + APIException.__init__(self, custom_message, status_code) + +def headers_from_metadata(sdk_metadata, client_key=None): + """ + Generate a dict with headers required by data-recording API endpoints. + :param sdk_metadata: SDK Metadata object, generated at sdk initialization time. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param client_key: client key. + :type client_key: str + :return: A dictionary with headers. + :rtype: dict + """ + + metadata = { + 'SplitSDKVersion': sdk_metadata.sdk_version, + 'SplitSDKMachineIP': sdk_metadata.instance_ip, + 'SplitSDKMachineName': sdk_metadata.instance_name + } if sdk_metadata.instance_ip != 'NA' and sdk_metadata.instance_ip != 'unknown' else { + 'SplitSDKVersion': sdk_metadata.sdk_version, + } + + if client_key is not None: + metadata['SplitSDKClientKey'] = client_key + + return metadata \ No newline at end of file diff --git a/splitio/api/auth.py b/splitio/api/auth.py index b55200c2..986ee31a 100644 --- a/splitio/api/auth.py +++ b/splitio/api/auth.py @@ -3,11 +3,13 @@ import logging import json -from splitio.api import APIException -from splitio.api.commons import headers_from_metadata +from splitio.api import APIException, headers_from_metadata +from splitio.api.commons import headers_from_metadata, record_telemetry +from splitio.spec import SPEC_VERSION +from splitio.util.time import get_current_epoch_time_ms from splitio.api.client import HttpClientException from splitio.models.token import from_raw - +from splitio.models.telemetry import HTTPExceptionsAndLatencies _LOGGER = logging.getLogger(__name__) @@ -15,20 +17,22 @@ class AuthAPI(object): # pylint: disable=too-few-public-methods """Class that uses an httpClient to communicate with the SDK Auth Service API.""" - def __init__(self, client, apikey, sdk_metadata): + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): """ Class constructor. :param client: HTTP Client responsble for issuing calls to the backend. :type client: HttpClient - :param apikey: User apikey token. - :type apikey: string + :param sdk_key: User sdk key. + :type sdk_key: string :param sdk_metadata: SDK version & machine name & IP. :type sdk_metadata: splitio.client.util.SdkMetadata """ self._client = client - self._apikey = apikey + self._sdk_key = sdk_key self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.TOKEN, self._telemetry_runtime_producer) def authenticate(self): """ @@ -40,14 +44,64 @@ def authenticate(self): try: response = self._client.get( 'auth', - '/v2/auth', - self._apikey, - extra_headers=self._metadata + 'v2/auth?s=' + SPEC_VERSION, + self._sdk_key, + extra_headers=self._metadata, + ) + if 200 <= response.status_code < 300: + payload = json.loads(response.body) + return from_raw(payload) + + else: + if (response.status_code >= 400 and response.status_code < 500): + self._telemetry_runtime_producer.record_auth_rejections() + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.error('Exception raised while authenticating') + _LOGGER.debug('Exception information: ', exc_info=True) + raise APIException('Could not perform authentication.') from exc + +class AuthAPIAsync(object): # pylint: disable=too-few-public-methods + """Async Class that uses an httpClient to communicate with the SDK Auth Service API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: User sdk key. + :type sdk_key: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._client = client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.TOKEN, self._telemetry_runtime_producer) + + async def authenticate(self): + """ + Perform authentication. + + :return: Json representation of an authentication. + :rtype: splitio.models.token.Token + """ + try: + response = await self._client.get( + 'auth', + 'v2/auth?s=' + SPEC_VERSION, + self._sdk_key, + extra_headers=self._metadata, ) if 200 <= response.status_code < 300: payload = json.loads(response.body) return from_raw(payload) + else: + if (response.status_code >= 400 and response.status_code < 500): + await self._telemetry_runtime_producer.record_auth_rejections() raise APIException(response.body, response.status_code) except HttpClientException as exc: _LOGGER.error('Exception raised while authenticating') diff --git a/splitio/api/client.py b/splitio/api/client.py index 505547e5..c9032e0e 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -1,10 +1,63 @@ """Synchronous HTTP Client for split API.""" from collections import namedtuple - import requests +import urllib +import abc +import logging +import json +import threading +from urllib3.util import parse_url + +from splitio.optional.loaders import HTTPKerberosAuth, OPTIONAL +from splitio.client.config import AuthenticateScheme +from splitio.optional.loaders import aiohttp +from splitio.util.time import get_current_epoch_time_ms + +SDK_URL = 'https://sdk.split.io/api' +EVENTS_URL = 'https://events.split.io/api' +AUTH_URL = 'https://auth.split.io/api' +TELEMETRY_URL = 'https://telemetry.split.io/api' + +_LOGGER = logging.getLogger(__name__) +_EXC_MSG = '{source} library is throwing exceptions' + +HttpResponse = namedtuple('HttpResponse', ['status_code', 'body', 'headers']) + +def _build_url(server, path, urls): + """ + Build URL according to server specified. + + :param server: Server for whith the request is being made. + :type server: str + :param path: URL path to be appended to base host. + :type path: str -HttpResponse = namedtuple('HttpResponse', ['status_code', 'body']) + :return: A fully qualified URL. + :rtype: str + """ + url = urls[server] + url += '/' if urls[server][:-1] != '/' else '' + return urllib.parse.urljoin(url, path) +def _construct_urls(sdk_url=None, events_url=None, auth_url=None, telemetry_url=None): + return { + 'sdk': sdk_url if sdk_url is not None else SDK_URL, + 'events': events_url if events_url is not None else EVENTS_URL, + 'auth': auth_url if auth_url is not None else AUTH_URL, + 'telemetry': telemetry_url if telemetry_url is not None else TELEMETRY_URL, + } + +def _build_basic_headers(sdk_key): + """ + Build basic headers with auth. + + :param sdk_key: API token used to identify backend calls. + :type sdk_key: str + """ + return { + 'Content-Type': 'application/json', + 'Authorization': "Bearer %s" % sdk_key + } class HttpClientException(Exception): """HTTP Client exception.""" @@ -18,15 +71,28 @@ def __init__(self, message): """ Exception.__init__(self, message) +class HTTPAdapterWithProxyKerberosAuth(requests.adapters.HTTPAdapter): + """HTTPAdapter override for Kerberos Proxy auth""" -class HttpClient(object): - """HttpClient wrapper.""" + def __init__(self, principal=None, password=None): + requests.adapters.HTTPAdapter.__init__(self) + self._principal = principal + self._password = password - SDK_URL = 'https://sdk.split.io/api' - EVENTS_URL = 'https://events.split.io/api' - AUTH_URL = 'https://auth.split.io/api' + def proxy_headers(self, proxy): + headers = {} + if self._principal is not None: + auth = HTTPKerberosAuth(principal=self._principal, password=self._password) + else: + auth = HTTPKerberosAuth() + negotiate_details = auth.generate_request_header(None, parse_url(proxy).host, is_preemptive=True) + headers['Proxy-Authorization'] = negotiate_details + return headers - def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None): +class HttpClientBase(object, metaclass=abc.ABCMeta): + """HttpClient wrapper template.""" + + def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None): """ Class constructor. @@ -38,42 +104,81 @@ def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None): :type events_url: str :param auth_url: Optional alternative auth URL. :type auth_url: str + :param telemetry_url: Optional alternative telemetry URL. + :type telemetry_url: str + """ + _LOGGER.debug("Initializing httpclient") + self._timeout = timeout/1000 if timeout else None # Convert ms to seconds. + self._urls = _construct_urls(sdk_url, events_url, auth_url, telemetry_url) + + @abc.abstractmethod + def get(self, server, path, apikey): + """http get request""" + + @abc.abstractmethod + def post(self, server, path, apikey): + """http post request""" + + def set_telemetry_data(self, metric_name, telemetry_runtime_producer): """ - self._timeout = timeout/1000 if timeout else None # Convert ms to seconds. - self._urls = { - 'sdk': sdk_url if sdk_url is not None else self.SDK_URL, - 'events': events_url if events_url is not None else self.EVENTS_URL, - 'auth': auth_url if auth_url is not None else self.AUTH_URL, - } + Set the data needed for telemetry call - def _build_url(self, server, path): + :param metric_name: metric name for telemetry + :type metric_name: str + + :param telemetry_runtime_producer: telemetry recording instance + :type telemetry_runtime_producer: splitio.engine.telemetry.TelemetryRuntimeProducer """ - Build URL according to server specified. + self._telemetry_runtime_producer = telemetry_runtime_producer + self._metric_name = metric_name - :param server: Server for whith the request is being made. - :type server: str - :param path: URL path to be appended to base host. - :type path: str + def is_sdk_endpoint_overridden(self): + return self._urls['sdk'] != SDK_URL + + def _get_headers(self, extra_headers, sdk_key): + headers = _build_basic_headers(sdk_key) + if extra_headers is not None: + headers.update(extra_headers) + return headers - :return: A fully qualified URL. - :rtype: str + def _record_telemetry(self, status_code, elapsed): """ - return self._urls[server] + path + Record Telemetry info - @staticmethod - def _build_basic_headers(apikey): + :param status_code: http request status code + :type status_code: int + + :param elapsed: response time elapsed. + :type status_code: int """ - Build basic headers with auth. + self._telemetry_runtime_producer.record_sync_latency(self._metric_name, elapsed) + if 200 <= status_code < 300: + self._telemetry_runtime_producer.record_successful_sync(self._metric_name, get_current_epoch_time_ms()) + return - :param apikey: API token used to identify backend calls. - :type apikey: str + self._telemetry_runtime_producer.record_sync_error(self._metric_name, status_code) + +class HttpClient(HttpClientBase): + """HttpClient wrapper.""" + + def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None): """ - return { - 'Content-Type': 'application/json', - 'Authorization': "Bearer %s" % apikey - } + Class constructor. - def get(self, server, path, apikey, query=None, extra_headers=None): # pylint: disable=too-many-arguments + :param timeout: How many milliseconds to wait until the server responds. + :type timeout: int + :param sdk_url: Optional alternative sdk URL. + :type sdk_url: str + :param events_url: Optional alternative events URL. + :type events_url: str + :param auth_url: Optional alternative auth URL. + :type auth_url: str + :param telemetry_url: Optional alternative telemetry URL. + :type telemetry_url: str + """ + HttpClientBase.__init__(self, timeout, sdk_url, events_url, auth_url, telemetry_url) + + def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: disable=too-many-arguments """ Issue a get request. @@ -81,8 +186,8 @@ def get(self, server, path, apikey, query=None, extra_headers=None): # pylint: :typee server: str :param path: path to append to the host url. :type path: str - :param apikey: api token. - :type apikey: str + :param sdk_key: sdk key. + :type sdk_key: str :param query: Query string passed as dictionary. :type query: dict :param extra_headers: key/value pairs of possible extra headers. @@ -91,22 +196,25 @@ def get(self, server, path, apikey, query=None, extra_headers=None): # pylint: :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = self._build_basic_headers(apikey) - if extra_headers is not None: - headers.update(extra_headers) - + start = get_current_epoch_time_ms() try: response = requests.get( - self._build_url(server, path), + _build_url(server, path, self._urls), params=query, - headers=headers, + headers=self._get_headers(extra_headers, sdk_key), timeout=self._timeout ) - return HttpResponse(response.status_code, response.text) - except Exception as exc: # pylint: disable=broad-except - raise HttpClientException('requests library is throwing exceptions') from exc + self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) + return HttpResponse(response.status_code, response.text, response.headers) + + except requests.exceptions.ChunkedEncodingError as exc: + _LOGGER.error("IncompleteRead exception detected: %s", exc) + return HttpResponse(400, "", {}) + + except Exception as exc: # pylint: disable=broad-except + raise HttpClientException(_EXC_MSG.format(source='request')) from exc - def post(self, server, path, apikey, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments + def post(self, server, path, sdk_key, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments """ Issue a POST request. @@ -114,8 +222,8 @@ def post(self, server, path, apikey, body, query=None, extra_headers=None): # p :typee server: str :param path: path to append to the host url. :type path: str - :param apikey: api token. - :type apikey: str + :param sdk_key: sdk key. + :type sdk_key: str :param body: body sent in the request. :type body: str :param query: Query string passed as dictionary. @@ -126,19 +234,332 @@ def post(self, server, path, apikey, body, query=None, extra_headers=None): # p :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = self._build_basic_headers(apikey) - - if extra_headers is not None: - headers.update(extra_headers) - + start = get_current_epoch_time_ms() try: response = requests.post( - self._build_url(server, path), + _build_url(server, path, self._urls), json=body, params=query, - headers=headers, - timeout=self._timeout + headers=self._get_headers(extra_headers, sdk_key), + timeout=self._timeout, ) - return HttpResponse(response.status_code, response.text) + self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) + return HttpResponse(response.status_code, response.text, response.headers) except Exception as exc: # pylint: disable=broad-except - raise HttpClientException('requests library is throwing exceptions') from exc + raise HttpClientException(_EXC_MSG.format(source='request')) from exc + +class HttpClientAsync(HttpClientBase): + """HttpClientAsync wrapper.""" + + def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None): + """ + Class constructor. + :param timeout: How many milliseconds to wait until the server responds. + :type timeout: int + :param sdk_url: Optional alternative sdk URL. + :type sdk_url: str + :param events_url: Optional alternative events URL. + :type events_url: str + :param auth_url: Optional alternative auth URL. + :type auth_url: str + :param telemetry_url: Optional alternative telemetry URL. + :type telemetry_url: str + """ + HttpClientBase.__init__(self, timeout, sdk_url, events_url, auth_url, telemetry_url) + self._session = aiohttp.ClientSession() + + async def get(self, server, path, apikey, query=None, extra_headers=None): # pylint: disable=too-many-arguments + """ + Issue a get request. + :param server: Whether the request is for SDK server, Events server or Auth server. + :typee server: str + :param path: path to append to the host url. + :type path: str + :param apikey: api token. + :type apikey: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + start = get_current_epoch_time_ms() + headers = self._get_headers(extra_headers, apikey) + try: + url = _build_url(server, path, self._urls) + _LOGGER.debug("GET request: %s", url) + _LOGGER.debug("query params: %s", query) + _LOGGER.debug("headers: %s", headers) + async with self._session.get( + url, + params=query, + headers=headers, + timeout=self._timeout + ) as response: + body = await response.text() + _LOGGER.debug("Response:") + _LOGGER.debug(response) + _LOGGER.debug(body) + await self._record_telemetry(response.status, get_current_epoch_time_ms() - start) + return HttpResponse(response.status, body, response.headers) + + except aiohttp.ClientPayloadError as exc: + _LOGGER.error("ContentLengthError exception detected: %s", exc) + return HttpResponse(400, "", {}) + + except aiohttp.ClientError as exc: # pylint: disable=broad-except + raise HttpClientException(_EXC_MSG.format(source='aiohttp')) from exc + + async def post(self, server, path, apikey, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments + """ + Issue a POST request. + :param server: Whether the request is for SDK server or Events server. + :typee server: str + :param path: path to append to the host url. + :type path: str + :param apikey: api token. + :type apikey: str + :param body: body sent in the request. + :type body: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + headers = self._get_headers(extra_headers, apikey) + start = get_current_epoch_time_ms() + try: + headers['Accept-Encoding'] = 'gzip' + _LOGGER.debug("POST request: %s", _build_url(server, path, self._urls)) + _LOGGER.debug("query params: %s", query) + _LOGGER.debug("headers: %s", headers) + _LOGGER.debug("payload: ") + _LOGGER.debug(str(json.dumps(body)).encode('utf-8')) + async with self._session.post( + _build_url(server, path, self._urls), + params=query, + headers=headers, + json=body, + timeout=self._timeout + ) as response: + body = await response.text() + _LOGGER.debug("Response:") + _LOGGER.debug(response) + _LOGGER.debug(body) + await self._record_telemetry(response.status, get_current_epoch_time_ms() - start) + return HttpResponse(response.status, body, response.headers) + + except aiohttp.ClientError as exc: # pylint: disable=broad-except + raise HttpClientException(_EXC_MSG.format(source='aiohttp')) from exc + + async def _record_telemetry(self, status_code, elapsed): + """ + Record Telemetry info + + :param status_code: http request status code + :type status_code: int + + :param elapsed: response time elapsed. + :type status_code: int + """ + await self._telemetry_runtime_producer.record_sync_latency(self._metric_name, elapsed) + if 200 <= status_code < 300: + await self._telemetry_runtime_producer.record_successful_sync(self._metric_name, get_current_epoch_time_ms()) + return + + await self._telemetry_runtime_producer.record_sync_error(self._metric_name, status_code) + + async def close_session(self): + if not self._session.closed: + await self._session.close() + +class HttpClientKerberos(HttpClientBase): + """HttpClient wrapper.""" + + def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None, authentication_scheme=None, authentication_params=None): + """ + Class constructor. + + :param timeout: How many milliseconds to wait until the server responds. + :type timeout: int + :param sdk_url: Optional alternative sdk URL. + :type sdk_url: str + :param events_url: Optional alternative events URL. + :type events_url: str + :param auth_url: Optional alternative auth URL. + :type auth_url: str + :param telemetry_url: Optional alternative telemetry URL. + :type telemetry_url: str + :param authentication_scheme: Optional authentication scheme to use. + :type authentication_scheme: splitio.client.config.AuthenticateScheme + :param authentication_params: Optional authentication username and password to use. + :type authentication_params: [str, str] + """ + _LOGGER.debug("Initializing httpclient for Kerberos auth") + self._timeout = timeout/1000 if timeout else None # Convert ms to seconds. + self._urls = _construct_urls(sdk_url, events_url, auth_url, telemetry_url) + self._authentication_scheme = authentication_scheme + self._authentication_params = authentication_params + self._lock = threading.RLock() + self._sessions = {'sdk': requests.Session(), + 'events': requests.Session(), + 'auth': requests.Session(), + 'telemetry': requests.Session()} + self._set_authentication() + + def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: disable=too-many-arguments + """ + Issue a get request. + :param server: Whether the request is for SDK server, Events server or Auth server. + :typee server: str + :param path: path to append to the host url. + :type path: str + :param sdk_key: sdk key. + :type sdk_key: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + with self._lock: + start = get_current_epoch_time_ms() + try: + return self._do_get(server, path, sdk_key, query, extra_headers, start) + + except requests.exceptions.ProxyError as exc: + _LOGGER.debug("Proxy Exception caught, resetting the http session") + self._sessions[server].close() + self._sessions[server] = requests.Session() + self._set_authentication(server_name=server) + try: + return self._do_get(server, path, sdk_key, query, extra_headers, start) + + except Exception as exc: + raise HttpClientException(_EXC_MSG.format(source='request')) from exc + + except Exception as exc: # pylint: disable=broad-except + raise HttpClientException(_EXC_MSG.format(source='request')) from exc + + def _do_get(self, server, path, sdk_key, query, extra_headers, start): + """ + Issue a get request. + :param server: Whether the request is for SDK server, Events server or Auth server. + :typee server: str + :param path: path to append to the host url. + :type path: str + :param sdk_key: sdk key. + :type sdk_key: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + with self._sessions[server].get( + _build_url(server, path, self._urls), + headers=self._get_headers(extra_headers, sdk_key), + params=query, + timeout=self._timeout + ) as response: + self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) + return HttpResponse(response.status_code, response.text, response.headers) + + def post(self, server, path, sdk_key, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments + """ + Issue a POST request. + + :param server: Whether the request is for SDK server or Events server. + :typee server: str + :param path: path to append to the host url. + :type path: str + :param sdk_key: sdk key. + :type sdk_key: str + :param body: body sent in the request. + :type body: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + with self._lock: + start = get_current_epoch_time_ms() + try: + return self._do_post(server, path, sdk_key, query, extra_headers, body, start) + + except requests.exceptions.ProxyError as exc: + _LOGGER.debug("Proxy Exception caught, resetting the http session") + self._sessions[server].close() + self._sessions[server] = requests.Session() + self._set_authentication(server_name=server) + try: + return self._do_post(server, path, sdk_key, query, extra_headers, body, start) + + except Exception as exc: + raise HttpClientException(_EXC_MSG.format(source='request')) from exc + + except Exception as exc: # pylint: disable=broad-except + raise HttpClientException(_EXC_MSG.format(source='request')) from exc + + def _do_post(self, server, path, sdk_key, query, extra_headers, body, start): + """ + Issue a POST request. + + :param server: Whether the request is for SDK server or Events server. + :typee server: str + :param path: path to append to the host url. + :type path: str + :param sdk_key: sdk key. + :type sdk_key: str + :param body: body sent in the request. + :type body: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + with self._sessions[server].post( + _build_url(server, path, self._urls), + params=query, + headers=self._get_headers(extra_headers, sdk_key), + json=body, + timeout=self._timeout, + ) as response: + self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start) + return HttpResponse(response.status_code, response.text, response.headers) + + def _set_authentication(self, server_name=None): + """ + Set the authentication for all self._sessions variables based on authentication scheme. + + :param server: If set, will only add the auth for its session variable, otherwise will set all sessions. + :typee server: str + """ + for server in ['sdk', 'events', 'auth', 'telemetry']: + if server_name is not None and server_name != server: + continue + if self._authentication_scheme == AuthenticateScheme.KERBEROS_SPNEGO: + _LOGGER.debug("Using Kerberos Spnego Authentication") + if self._authentication_params != [None, None]: + self._sessions[server].auth = HTTPKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1], mutual_authentication=OPTIONAL) + else: + self._sessions[server].auth = HTTPKerberosAuth(mutual_authentication=OPTIONAL) + elif self._authentication_scheme == AuthenticateScheme.KERBEROS_PROXY: + _LOGGER.debug("Using Kerberos Proxy Authentication") + if self._authentication_params != [None, None]: + self._sessions[server].mount('https://', HTTPAdapterWithProxyKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1])) + else: + self._sessions[server].mount('https://', HTTPAdapterWithProxyKerberosAuth()) diff --git a/splitio/api/commons.py b/splitio/api/commons.py index 53019427..9dda1ee0 100644 --- a/splitio/api/commons.py +++ b/splitio/api/commons.py @@ -1,10 +1,10 @@ """Commons module.""" - +from splitio.util.time import get_current_epoch_time_ms +from splitio.spec import SPEC_VERSION _CACHE_CONTROL = 'Cache-Control' _CACHE_CONTROL_NO_CACHE = 'no-cache' - def headers_from_metadata(sdk_metadata, client_key=None): """ Generate a dict with headers required by data-recording API endpoints. @@ -32,11 +32,32 @@ def headers_from_metadata(sdk_metadata, client_key=None): return metadata +def record_telemetry(status_code, elapsed, metric_name, telemetry_runtime_producer): + """ + Record Telemetry info + + :param status_code: http request status code + :type status_code: int + + :param elapsed: response time elapsed. + :type status_code: int + + :param metric_name: metric name for telemetry + :type metric_name: str + + :param telemetry_runtime_producer: telemetry recording instance + :type telemetry_runtime_producer: splitio.engine.telemetry.TelemetryRuntimeProducer + """ + telemetry_runtime_producer.record_sync_latency(metric_name, elapsed) + if 200 <= status_code < 300: + telemetry_runtime_producer.record_successful_sync(metric_name, get_current_epoch_time_ms()) + return + telemetry_runtime_producer.record_sync_error(metric_name, status_code) class FetchOptions(object): """Fetch Options object.""" - def __init__(self, cache_control_headers=False, change_number=None): + def __init__(self, cache_control_headers=False, change_number=None, rbs_change_number=None, sets=None, spec=SPEC_VERSION): """ Class constructor. @@ -45,9 +66,15 @@ def __init__(self, cache_control_headers=False, change_number=None): :param change_number: ChangeNumber to use for bypassing CDN in request. :type change_number: int + + :param sets: list of flag sets + :type sets: list """ self._cache_control_headers = cache_control_headers self._change_number = change_number + self._rbs_change_number = rbs_change_number + self._sets = sets + self._spec = spec @property def cache_control_headers(self): @@ -59,16 +86,42 @@ def change_number(self): """Return change number.""" return self._change_number + @property + def rbs_change_number(self): + """Return change number.""" + return self._rbs_change_number + + @property + def sets(self): + """Return sets.""" + return self._sets + + @property + def spec(self): + """Return sets.""" + return self._spec + def __eq__(self, other): """Match between other options.""" if self._cache_control_headers != other._cache_control_headers: return False + if self._change_number != other._change_number: return False + + if self._rbs_change_number != other._rbs_change_number: + return False + + if self._sets != other._sets: + return False + + if self._spec != other._spec: + return False + return True -def build_fetch(change_number, fetch_options, metadata): +def build_fetch(change_number, fetch_options, metadata, rbs_change_number=None): """ Build fetch with new flags if that is the case. @@ -81,15 +134,24 @@ def build_fetch(change_number, fetch_options, metadata): :param metadata: Metadata Headers. :type metadata: dict + :param rbs_change_number: Last known timestamp of a rule based segment modification. + :type rbs_change_number: int + :return: Objects for fetch :rtype: dict, dict """ - query = {'since': change_number} + query = {'s': fetch_options.spec} if fetch_options.spec is not None else {} + query['since'] = change_number + if rbs_change_number is not None: + query['rbSince'] = rbs_change_number extra_headers = metadata if fetch_options is None: return query, extra_headers + if fetch_options.cache_control_headers: extra_headers[_CACHE_CONTROL] = _CACHE_CONTROL_NO_CACHE + if fetch_options.sets is not None: + query['sets'] = fetch_options.sets if fetch_options.change_number is not None: query['till'] = fetch_options.change_number - return query, extra_headers + return query, extra_headers \ No newline at end of file diff --git a/splitio/api/events.py b/splitio/api/events.py index b8ddda36..16beeddc 100644 --- a/splitio/api/events.py +++ b/splitio/api/events.py @@ -1,31 +1,16 @@ """Events API module.""" import logging -from splitio.api import APIException +from splitio.api import APIException, headers_from_metadata from splitio.api.client import HttpClientException -from splitio.api.commons import headers_from_metadata +from splitio.models.telemetry import HTTPExceptionsAndLatencies _LOGGER = logging.getLogger(__name__) -class EventsAPI(object): # pylint: disable=too-few-public-methods - """Class that uses an httpClient to communicate with the events API.""" - - def __init__(self, http_client, apikey, sdk_metadata): - """ - Class constructor. - - :param http_client: HTTP Client responsble for issuing calls to the backend. - :type http_client: HttpClient - :param apikey: User apikey token. - :type apikey: string - :param sdk_metadata: SDK version & machine name & IP. - :type sdk_metadata: splitio.client.util.SdkMetadata - """ - self._client = http_client - self._apikey = apikey - self._metadata = headers_from_metadata(sdk_metadata) +class EventsAPIBase(object): # pylint: disable=too-few-public-methods + """Base Class that uses an httpClient to communicate with the events API.""" @staticmethod def _build_bulk(events): @@ -50,6 +35,27 @@ def _build_bulk(events): for event in events ] + +class EventsAPI(EventsAPIBase): # pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the events API.""" + + def __init__(self, http_client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param http_client: HTTP Client responsble for issuing calls to the backend. + :type http_client: HttpClient + :param sdk_key: sdk key. + :type sdk_key: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._client = http_client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.EVENT, self._telemetry_runtime_producer) + def flush_events(self, events): """ Send events to the backend. @@ -64,10 +70,56 @@ def flush_events(self, events): try: response = self._client.post( 'events', - '/events/bulk', - self._apikey, + 'events/bulk', + self._sdk_key, + body=bulk, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.error('Error posting events because an exception was raised by the HTTPClient') + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Events not flushed properly.') from exc + +class EventsAPIAsync(EventsAPIBase): # pylint: disable=too-few-public-methods + """Async Class that uses an httpClient to communicate with the events API.""" + + def __init__(self, http_client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param http_client: HTTP Client responsble for issuing calls to the backend. + :type http_client: HttpClient + :param sdk_key: sdk key. + :type sdk_key: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._client = http_client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.EVENT, self._telemetry_runtime_producer) + + async def flush_events(self, events): + """ + Send events to the backend. + + :param events: Events bulk + :type events: list + + :return: True if flush was successful. False otherwise + :rtype: bool + """ + bulk = self._build_bulk(events) + try: + response = await self._client.post( + 'events', + 'events/bulk', + self._sdk_key, body=bulk, - extra_headers=self._metadata + extra_headers=self._metadata, ) if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) diff --git a/splitio/api/impressions.py b/splitio/api/impressions.py index 02206a1e..da85691b 100644 --- a/splitio/api/impressions.py +++ b/splitio/api/impressions.py @@ -3,31 +3,17 @@ import logging from itertools import groupby -from splitio.api import APIException +from splitio.api import APIException, headers_from_metadata from splitio.api.client import HttpClientException -from splitio.api.commons import headers_from_metadata from splitio.engine.impressions import ImpressionsMode +from splitio.models.telemetry import HTTPExceptionsAndLatencies _LOGGER = logging.getLogger(__name__) -class ImpressionsAPI(object): # pylint: disable=too-few-public-methods - """Class that uses an httpClient to communicate with the impressions API.""" - - def __init__(self, client, apikey, sdk_metadata, mode=ImpressionsMode.OPTIMIZED): - """ - Class constructor. - - :param client: HTTP Client responsble for issuing calls to the backend. - :type client: HttpClient - :param apikey: User apikey token. - :type apikey: string - """ - self._client = client - self._apikey = apikey - self._metadata = headers_from_metadata(sdk_metadata) - self._metadata['SplitSDKImpressionsMode'] = mode.name +class ImpressionsAPIBase(object): # pylint: disable=too-few-public-methods + """Base Class that uses an httpClient to communicate with the impressions API.""" @staticmethod def _build_bulk(impressions): @@ -44,15 +30,7 @@ def _build_bulk(impressions): { 'f': test_name, 'i': [ - { - 'k': impression.matching_key, - 't': impression.treatment, - 'm': impression.time, - 'c': impression.change_number, - 'r': impression.label, - 'b': impression.bucketing_key, - 'pt': impression.previous_time - } + ImpressionsAPIBase._filter_out_null_prop(impression) for impression in imps ] } @@ -62,6 +40,30 @@ def _build_bulk(impressions): ) ] + @staticmethod + def _filter_out_null_prop(impression): + if impression.properties == None: + return { + 'k': impression.matching_key, + 't': impression.treatment, + 'm': impression.time, + 'c': impression.change_number, + 'r': impression.label, + 'b': impression.bucketing_key, + 'pt': impression.previous_time + } + + return { + 'k': impression.matching_key, + 't': impression.treatment, + 'm': impression.time, + 'c': impression.change_number, + 'r': impression.label, + 'b': impression.bucketing_key, + 'pt': impression.previous_time, + 'properties': impression.properties + } + @staticmethod def _build_counters(counters): """ @@ -83,6 +85,25 @@ def _build_counters(counters): ] } + +class ImpressionsAPI(ImpressionsAPIBase): # pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the impressions API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer, mode=ImpressionsMode.OPTIMIZED): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: sdk key. + :type sdk_key: string + """ + self._client = client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._metadata['SplitSDKImpressionsMode'] = mode.name + self._telemetry_runtime_producer = telemetry_runtime_producer + def flush_impressions(self, impressions): """ Send impressions to the backend. @@ -91,13 +112,14 @@ def flush_impressions(self, impressions): :type impressions: list """ bulk = self._build_bulk(impressions) + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.IMPRESSION, self._telemetry_runtime_producer) try: response = self._client.post( 'events', - '/testImpressions/bulk', - self._apikey, + 'testImpressions/bulk', + self._sdk_key, body=bulk, - extra_headers=self._metadata + extra_headers=self._metadata, ) if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) @@ -116,13 +138,86 @@ def flush_counters(self, counters): :type impressions: list """ bulk = self._build_counters(counters) + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.IMPRESSION_COUNT, self._telemetry_runtime_producer) try: response = self._client.post( 'events', - '/testImpressions/count', - self._apikey, + 'testImpressions/count', + self._sdk_key, + body=bulk, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.error( + 'Error posting impressions counters because an exception was raised by the ' + 'HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Impressions not flushed properly.') from exc + + +class ImpressionsAPIAsync(ImpressionsAPIBase): # pylint: disable=too-few-public-methods + """Async Class that uses an httpClient to communicate with the impressions API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer, mode=ImpressionsMode.OPTIMIZED): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: sdk key. + :type sdk_key: string + """ + self._client = client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._metadata['SplitSDKImpressionsMode'] = mode.name + self._telemetry_runtime_producer = telemetry_runtime_producer + + async def flush_impressions(self, impressions): + """ + Send impressions to the backend. + + :param impressions: Impressions bulk + :type impressions: list + """ + bulk = self._build_bulk(impressions) + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.IMPRESSION, self._telemetry_runtime_producer) + try: + response = await self._client.post( + 'events', + 'testImpressions/bulk', + self._sdk_key, + body=bulk, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.error( + 'Error posting impressions because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Impressions not flushed properly.') from exc + + async def flush_counters(self, counters): + """ + Send impressions to the backend. + + :param impressions: Impressions bulk + :type impressions: list + """ + bulk = self._build_counters(counters) + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.IMPRESSION_COUNT, self._telemetry_runtime_producer) + try: + response = await self._client.post( + 'events', + 'testImpressions/count', + self._sdk_key, body=bulk, - extra_headers=self._metadata + extra_headers=self._metadata, ) if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) diff --git a/splitio/api/segments.py b/splitio/api/segments.py index ebc65a7e..aae33ac6 100644 --- a/splitio/api/segments.py +++ b/splitio/api/segments.py @@ -3,9 +3,10 @@ import json import logging -from splitio.api import APIException -from splitio.api.commons import headers_from_metadata, build_fetch +from splitio.api import APIException, headers_from_metadata +from splitio.api.commons import build_fetch from splitio.api.client import HttpClientException +from splitio.models.telemetry import HTTPExceptionsAndLatencies _LOGGER = logging.getLogger(__name__) @@ -14,21 +15,23 @@ class SegmentsAPI(object): # pylint: disable=too-few-public-methods """Class that uses an httpClient to communicate with the segments API.""" - def __init__(self, http_client, apikey, sdk_metadata): + def __init__(self, http_client, sdk_key, sdk_metadata, telemetry_runtime_producer): """ Class constructor. :param client: HTTP Client responsble for issuing calls to the backend. :type client: client.HttpClient - :param apikey: User apikey token. - :type apikey: string + :param sdk_key: User sdk_key token. + :type sdk_key: string :param sdk_metadata: SDK version & machine name & IP. :type sdk_metadata: splitio.client.util.SdkMetadata """ self._client = http_client - self._apikey = apikey + self._sdk_key = sdk_key self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.SEGMENT, self._telemetry_runtime_producer) def fetch_segment(self, segment_name, change_number, fetch_options): """ @@ -50,16 +53,74 @@ def fetch_segment(self, segment_name, change_number, fetch_options): query, extra_headers = build_fetch(change_number, fetch_options, self._metadata) response = self._client.get( 'sdk', - '/segmentChanges/{segment_name}'.format(segment_name=segment_name), - self._apikey, + 'segmentChanges/{segment_name}'.format(segment_name=segment_name), + self._sdk_key, extra_headers=extra_headers, query=query, ) + if 200 <= response.status_code < 300: + return json.loads(response.body) + + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.error( + 'Error fetching %s because an exception was raised by the HTTPClient', + segment_name + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Segments not fetched properly.') from exc + + +class SegmentsAPIAsync(object): # pylint: disable=too-few-public-methods + """Async Class that uses an httpClient to communicate with the segments API.""" + + def __init__(self, http_client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: client.HttpClient + :param sdk_key: User sdk_key token. + :type sdk_key: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + + """ + self._client = http_client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.SEGMENT, self._telemetry_runtime_producer) + + async def fetch_segment(self, segment_name, change_number, fetch_options): + """ + Fetch splits from backend. + + :param segment_name: Name of the segment to fetch changes for. + :type segment_name: str + + :param change_number: Last known timestamp of a segment modification. + :type change_number: int + :param fetch_options: Fetch options for getting segment definitions. + :type fetch_options: splitio.api.commons.FetchOptions + + :return: Json representation of a segmentChange response. + :rtype: dict + """ + try: + query, extra_headers = build_fetch(change_number, fetch_options, self._metadata) + response = await self._client.get( + 'sdk', + 'segmentChanges/{segment_name}'.format(segment_name=segment_name), + self._sdk_key, + extra_headers=extra_headers, + query=query, + ) if 200 <= response.status_code < 300: return json.loads(response.body) - else: - raise APIException(response.body, response.status_code) + + raise APIException(response.body, response.status_code) except HttpClientException as exc: _LOGGER.error( 'Error fetching %s because an exception was raised by the HTTPClient', diff --git a/splitio/api/splits.py b/splitio/api/splits.py index e395d454..771100fc 100644 --- a/splitio/api/splits.py +++ b/splitio/api/splits.py @@ -3,59 +3,200 @@ import logging import json -from splitio.api import APIException -from splitio.api.commons import headers_from_metadata, build_fetch +from splitio.api import APIException, headers_from_metadata +from splitio.api.commons import build_fetch, FetchOptions from splitio.api.client import HttpClientException - +from splitio.models.telemetry import HTTPExceptionsAndLatencies +from splitio.util.time import utctime_ms +from splitio.spec import SPEC_VERSION +from splitio.sync import util _LOGGER = logging.getLogger(__name__) +_SPEC_1_1 = "1.1" +_PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 24 * 60 * 60 * 1000 - -class SplitsAPI(object): # pylint: disable=too-few-public-methods +class SplitsAPIBase(object): # pylint: disable=too-few-public-methods """Class that uses an httpClient to communicate with the splits API.""" - def __init__(self, client, apikey, sdk_metadata): + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): """ Class constructor. :param client: HTTP Client responsble for issuing calls to the backend. :type client: HttpClient - :param apikey: User apikey token. - :type apikey: string + :param sdk_key: User sdk_key token. + :type sdk_key: string :param sdk_metadata: SDK version & machine name & IP. :type sdk_metadata: splitio.client.util.SdkMetadata """ self._client = client - self._apikey = apikey + self._sdk_key = sdk_key self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.SPLIT, self._telemetry_runtime_producer) + self._spec_version = SPEC_VERSION + self._last_proxy_check_timestamp = 0 + self.clear_storage = False + self._old_spec_since = None + + def _check_last_proxy_check_timestamp(self, since): + if self._spec_version == _SPEC_1_1 and ((utctime_ms() - self._last_proxy_check_timestamp) >= _PROXY_CHECK_INTERVAL_MILLISECONDS_SS): + _LOGGER.info("Switching to new Feature flag spec (%s) and fetching.", SPEC_VERSION); + self._spec_version = SPEC_VERSION + self._old_spec_since = since + + def _check_old_spec_since(self, change_number): + if self._spec_version == _SPEC_1_1 and self._old_spec_since is not None: + since = self._old_spec_since + self._old_spec_since = None + return since + return change_number + + +class SplitsAPI(SplitsAPIBase): # pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the splits API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: User sdk_key token. + :type sdk_key: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + SplitsAPIBase.__init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer) - def fetch_splits(self, change_number, fetch_options): + def fetch_splits(self, change_number, rbs_change_number, fetch_options): """ - Fetch splits from backend. + Fetch feature flags from backend. :param change_number: Last known timestamp of a split modification. :type change_number: int - :param fetch_options: Fetch options for getting split definitions. + :param rbs_change_number: Last known timestamp of a rule based segment modification. + :type rbs_change_number: int + + :param fetch_options: Fetch options for getting feature flag definitions. :type fetch_options: splitio.api.commons.FetchOptions :return: Json representation of a splitChanges response. :rtype: dict """ try: - query, extra_headers = build_fetch(change_number, fetch_options, self._metadata) + self._check_last_proxy_check_timestamp(change_number) + change_number = self._check_old_spec_since(change_number) + + if self._spec_version == _SPEC_1_1: + fetch_options = FetchOptions(fetch_options.cache_control_headers, fetch_options.change_number, + None, fetch_options.sets, self._spec_version) + rbs_change_number = None + query, extra_headers = build_fetch(change_number, fetch_options, self._metadata, rbs_change_number) response = self._client.get( 'sdk', - '/splitChanges', - self._apikey, + 'splitChanges', + self._sdk_key, + extra_headers=extra_headers, + query=query, + ) + if 200 <= response.status_code < 300: + if self._spec_version == _SPEC_1_1: + return util.convert_to_new_spec(json.loads(response.body)) + + self.clear_storage = self._last_proxy_check_timestamp != 0 + self._last_proxy_check_timestamp = 0 + return json.loads(response.body) + + else: + if response.status_code == 414: + _LOGGER.error('Error fetching feature flags; the amount of flag sets provided are too big, causing uri length error.') + + if self._client.is_sdk_endpoint_overridden() and response.status_code == 400 and self._spec_version == SPEC_VERSION: + _LOGGER.warning('Detected proxy response error, changing spec version from %s to %s and re-fetching.', self._spec_version, _SPEC_1_1) + self._spec_version = _SPEC_1_1 + self._last_proxy_check_timestamp = utctime_ms() + return self.fetch_splits(change_number, None, FetchOptions(fetch_options.cache_control_headers, fetch_options.change_number, + None, fetch_options.sets, self._spec_version)) + + raise APIException(response.body, response.status_code) + + except HttpClientException as exc: + _LOGGER.error('Error fetching feature flags because an exception was raised by the HTTPClient') + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Feature flags not fetched correctly.') from exc + +class SplitsAPIAsync(SplitsAPIBase): # pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the splits API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: User sdk_key token. + :type sdk_key: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + SplitsAPIBase.__init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer) + + async def fetch_splits(self, change_number, rbs_change_number, fetch_options): + """ + Fetch feature flags from backend. + + :param change_number: Last known timestamp of a split modification. + :type change_number: int + + :param rbs_change_number: Last known timestamp of a rule based segment modification. + :type rbs_change_number: int + + :param fetch_options: Fetch options for getting feature flag definitions. + :type fetch_options: splitio.api.commons.FetchOptions + + :return: Json representation of a splitChanges response. + :rtype: dict + """ + try: + self._check_last_proxy_check_timestamp(change_number) + change_number = self._check_old_spec_since(change_number) + if self._spec_version == _SPEC_1_1: + fetch_options = FetchOptions(fetch_options.cache_control_headers, fetch_options.change_number, + None, fetch_options.sets, self._spec_version) + rbs_change_number = None + + query, extra_headers = build_fetch(change_number, fetch_options, self._metadata, rbs_change_number) + response = await self._client.get( + 'sdk', + 'splitChanges', + self._sdk_key, extra_headers=extra_headers, query=query, ) if 200 <= response.status_code < 300: + if self._spec_version == _SPEC_1_1: + return util.convert_to_new_spec(json.loads(response.body)) + + self.clear_storage = self._last_proxy_check_timestamp != 0 + self._last_proxy_check_timestamp = 0 return json.loads(response.body) + else: + if response.status_code == 414: + _LOGGER.error('Error fetching feature flags; the amount of flag sets provided are too big, causing uri length error.') + + if self._client.is_sdk_endpoint_overridden() and response.status_code == 400 and self._spec_version == SPEC_VERSION: + _LOGGER.warning('Detected proxy response error, changing spec version from %s to %s and re-fetching.', self._spec_version, _SPEC_1_1) + self._spec_version = _SPEC_1_1 + self._last_proxy_check_timestamp = utctime_ms() + return await self.fetch_splits(change_number, None, FetchOptions(fetch_options.cache_control_headers, fetch_options.change_number, + None, fetch_options.sets, self._spec_version)) + raise APIException(response.body, response.status_code) + except HttpClientException as exc: - _LOGGER.error('Error fetching splits because an exception was raised by the HTTPClient') + _LOGGER.error('Error fetching feature flags because an exception was raised by the HTTPClient') _LOGGER.debug('Error: ', exc_info=True) - raise APIException('Splits not fetched correctly.') from exc + raise APIException('Feature flags not fetched correctly.') from exc diff --git a/splitio/api/telemetry.py b/splitio/api/telemetry.py new file mode 100644 index 00000000..48f2ad2d --- /dev/null +++ b/splitio/api/telemetry.py @@ -0,0 +1,187 @@ +"""Impressions API module.""" +import logging + +from splitio.api import APIException, headers_from_metadata +from splitio.api.client import HttpClientException +from splitio.models.telemetry import HTTPExceptionsAndLatencies + +_LOGGER = logging.getLogger(__name__) + +class TelemetryAPI(object): # pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the Telemetry API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: User sdk_key token. + :type sdk_key: string + """ + self._client = client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.TELEMETRY, self._telemetry_runtime_producer) + + def record_unique_keys(self, uniques): + """ + Send unique keys to the backend. + + :param uniques: Unique Keys + :type json + """ + try: + response = self._client.post( + 'telemetry', + 'v1/keys/ss', + self._sdk_key, + body=uniques, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.debug( + 'Error posting unique keys because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Unique keys not flushed properly.') from exc + + def record_init(self, configs): + """ + Send init config data to the backend. + + :param configs: configs + :type json + """ + try: + response = self._client.post( + 'telemetry', + 'v1/metrics/config', + self._sdk_key, + body=configs, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.debug( + 'Error posting init config because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + + def record_stats(self, stats): + """ + Send runtime stats to the backend. + + :param stats: stats + :type json + """ + try: + response = self._client.post( + 'telemetry', + 'v1/metrics/usage', + self._sdk_key, + body=stats, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.debug( + 'Error posting runtime stats because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Runtime stats not flushed properly.') from exc + + +class TelemetryAPIAsync(object): # pylint: disable=too-few-public-methods + """Async Class that uses an httpClient to communicate with the Telemetry API.""" + + def __init__(self, client, sdk_key, sdk_metadata, telemetry_runtime_producer): + """ + Class constructor. + + :param client: HTTP Client responsble for issuing calls to the backend. + :type client: HttpClient + :param sdk_key: User sdk_key token. + :type sdk_key: string + """ + self._client = client + self._sdk_key = sdk_key + self._metadata = headers_from_metadata(sdk_metadata) + self._telemetry_runtime_producer = telemetry_runtime_producer + self._client.set_telemetry_data(HTTPExceptionsAndLatencies.TELEMETRY, self._telemetry_runtime_producer) + + async def record_unique_keys(self, uniques): + """ + Send unique keys to the backend. + + :param uniques: Unique Keys + :type json + """ + try: + response = await self._client.post( + 'telemetry', + 'v1/keys/ss', + self._sdk_key, + body=uniques, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.debug( + 'Error posting unique keys because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Unique keys not flushed properly.') from exc + + async def record_init(self, configs): + """ + Send init config data to the backend. + + :param configs: configs + :type json + """ + try: + response = await self._client.post( + 'telemetry', + 'v1/metrics/config', + self._sdk_key, + body=configs, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.debug( + 'Error posting init config because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + + async def record_stats(self, stats): + """ + Send runtime stats to the backend. + + :param stats: stats + :type json + """ + try: + response = await self._client.post( + 'telemetry', + 'v1/metrics/usage', + self._sdk_key, + body=stats, + extra_headers=self._metadata, + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + _LOGGER.debug( + 'Error posting runtime stats because an exception was raised by the HTTPClient' + ) + _LOGGER.debug('Error: ', exc_info=True) + raise APIException('Runtime stats not flushed properly.') from exc diff --git a/splitio/client/client.py b/splitio/client/client.py index 5fca7151..3c61166d 100644 --- a/splitio/client/client.py +++ b/splitio/client/client.py @@ -1,27 +1,47 @@ """A module for Split.io SDK API clients.""" import logging -import time -from splitio.engine.evaluator import Evaluator, CONTROL -from splitio.engine.splitters import Splitter -from splitio.models.impressions import Impression, Label -from splitio.models.events import Event, EventWrapper -from splitio.models.telemetry import get_latency_bucket_index +import json +from collections import namedtuple +import copy + from splitio.client import input_validator -from splitio.util import utctime_ms +from splitio.engine.evaluator import Evaluator, CONTROL, EvaluationDataFactory, AsyncEvaluationDataFactory +from splitio.engine.splitters import Splitter +from splitio.models.impressions import Impression, Label, ImpressionDecorated +from splitio.models.events import Event, EventWrapper, SdkEvent +from splitio.models.telemetry import get_latency_bucket_index, MethodExceptionsAndLatencies +from splitio.optional.loaders import asyncio +from splitio.util.time import get_current_epoch_time_ms, utctime_ms _LOGGER = logging.getLogger(__name__) +EvaluationOptions = namedtuple('EvaluationOptions', ['properties']) -class Client(object): # pylint: disable=too-many-instance-attributes +class ClientBase(object): # pylint: disable=too-many-instance-attributes """Entry point for the split sdk.""" - _METRIC_GET_TREATMENT = 'sdk.getTreatment' - _METRIC_GET_TREATMENTS = 'sdk.getTreatments' - _METRIC_GET_TREATMENT_WITH_CONFIG = 'sdk.getTreatmentWithConfig' - _METRIC_GET_TREATMENTS_WITH_CONFIG = 'sdk.getTreatmentsWithConfig' + _FAILED_EVAL_RESULT = { + 'treatment': CONTROL, + 'configurations': None, + 'impression': { + 'label': Label.EXCEPTION, + 'change_number': None, + }, + 'impressions_disabled': False + } + + _NON_READY_EVAL_RESULT = { + 'treatment': CONTROL, + 'configurations': None, + 'impression': { + 'label': Label.NOT_READY, + 'change_number': None + }, + 'impressions_disabled': False + } - def __init__(self, factory, recorder, labels_enabled=True): + def __init__(self, factory, recorder, events_manager, labels_enabled=True, fallback_treatment_calculator=None): """ Construct a Client instance. @@ -40,18 +60,14 @@ def __init__(self, factory, recorder, labels_enabled=True): self._labels_enabled = labels_enabled self._recorder = recorder self._splitter = Splitter() - self._split_storage = factory._get_storage('splits') # pylint: disable=protected-access + self._feature_flag_storage = factory._get_storage('splits') # pylint: disable=protected-access self._segment_storage = factory._get_storage('segments') # pylint: disable=protected-access self._events_storage = factory._get_storage('events') # pylint: disable=protected-access - self._evaluator = Evaluator(self._split_storage, self._segment_storage, self._splitter) - - def destroy(self): - """ - Destroy the underlying factory. - - Only applicable when using in-memory operation mode. - """ - self._factory.destroy() + self._evaluator = Evaluator(self._splitter, fallback_treatment_calculator) + self._telemetry_evaluation_producer = self._factory._telemetry_evaluation_producer + self._telemetry_init_producer = self._factory._telemetry_init_producer + self._fallback_treatment_calculator = fallback_treatment_calculator + self._events_manager = events_manager @property def ready(self): @@ -63,175 +79,712 @@ def destroyed(self): """Return whether the factory holding this client has been destroyed.""" return self._factory.destroyed - def _evaluate_if_ready(self, matching_key, bucketing_key, feature, attributes=None): - if not self.ready: - return { - 'treatment': CONTROL, - 'configurations': None, - 'impression': { - 'label': Label.NOT_READY, - 'change_number': None - } - } - - return self._evaluator.evaluate_feature( - feature, - matching_key, - bucketing_key, - attributes - ) + def _client_is_usable(self): + if self.destroyed: + _LOGGER.error("Client has already been destroyed - no calls possible") + return False - def _make_evaluation(self, key, feature, attributes, method_name, metric_name): - try: - if self.destroyed: - _LOGGER.error("Client has already been destroyed - no calls possible") - return CONTROL, None - if self._factory._waiting_fork(): - _LOGGER.error("Client is not ready - no calls possible") - return CONTROL, None - - start = int(round(time.time() * 1000)) - - matching_key, bucketing_key = input_validator.validate_key(key, method_name) - feature = input_validator.validate_feature_name( - feature, - self.ready, - self._factory._get_storage('splits'), # pylint: disable=protected-access - method_name - ) - - if (matching_key is None and bucketing_key is None) \ - or feature is None \ - or not input_validator.validate_attributes(attributes, method_name): - return CONTROL, None - - result = self._evaluate_if_ready(matching_key, bucketing_key, feature, attributes) - - impression = self._build_impression( - matching_key, - feature, - result['treatment'], - result['impression']['label'], - result['impression']['change_number'], - bucketing_key, - utctime_ms(), - ) - - self._record_stats([(impression, attributes)], start, metric_name) - return result['treatment'], result['configurations'] - except Exception: # pylint: disable=broad-except - _LOGGER.error('Error getting treatment for feature') - _LOGGER.debug('Error: ', exc_info=True) - try: - impression = self._build_impression( - matching_key, - feature, - CONTROL, - Label.EXCEPTION, - self._split_storage.get_change_number(), - bucketing_key, - utctime_ms(), - ) - self._record_stats([(impression, attributes)], start, metric_name) - except Exception: # pylint: disable=broad-except - _LOGGER.error('Error reporting impression into get_treatment exception block') - _LOGGER.debug('Error: ', exc_info=True) - return CONTROL, None + if self._factory._waiting_fork(): + _LOGGER.error("Client is not ready - no calls possible") + return False + + return True + + @staticmethod + def _validate_treatment_input(key, feature, attributes, method, evaluation_options=None): + """Perform all static validations on user supplied input.""" + matching_key, bucketing_key = input_validator.validate_key(key, 'get_' + method.value) + if not matching_key: + raise _InvalidInputError() + + feature = input_validator.validate_feature_flag_name(feature, 'get_' + method.value) + if not feature: + raise _InvalidInputError() + + if not input_validator.validate_attributes(attributes, 'get_' + method.value): + raise _InvalidInputError() + + evaluation_options = ClientBase._validate_treatment_options('get_' + method.value, evaluation_options) + return matching_key, bucketing_key, feature, attributes, evaluation_options + + @staticmethod + def _validate_treatments_input(key, features, attributes, method, evaluation_options=None): + """Perform all static validations on user supplied input.""" + matching_key, bucketing_key = input_validator.validate_key(key, 'get_' + method.value) + if not matching_key: + raise _InvalidInputError() - def _make_evaluations(self, key, features, attributes, method_name, metric_name): + features = input_validator.validate_feature_flags_get_treatments('get_' + method.value, features) + if not features: + raise _InvalidInputError() + + if not input_validator.validate_attributes(attributes, 'get_' + method.value): + raise _InvalidInputError() + + evaluation_options = ClientBase._validate_treatment_options('get_' + method.value, evaluation_options) + return matching_key, bucketing_key, features, attributes, evaluation_options + + @staticmethod + def _validate_treatment_options(method_name, evaluation_options=None): + evaluation_options = input_validator.validate_evaluation_options(evaluation_options, method_name) + if evaluation_options == None: + return None + + if evaluation_options.properties is not None: + valid, properties, size = input_validator.valid_properties(evaluation_options.properties, method_name) + evaluation_options = EvaluationOptions(properties) + if not valid: + evaluation_options = EvaluationOptions(None) + return evaluation_options + + def _build_impression(self, key, bucketing, feature, result, properties=None): + """Build an impression based on evaluation data & it's result.""" + return ImpressionDecorated( + Impression(matching_key=key, + feature_name=feature, + treatment=result['treatment'], + label=result['impression']['label'] if self._labels_enabled else None, + change_number=result['impression']['change_number'], + bucketing_key=bucketing, + time=utctime_ms(), + previous_time=None, + properties=json.dumps(properties) if properties is not None else None), + disabled=result['impressions_disabled']) + + def _build_impressions(self, key, bucketing, results, properties=None): + """Build an impression based on evaluation data & it's result.""" + return [ + self._build_impression(key, bucketing, feature, result, properties) + for feature, result in results.items() + ] + + def _validate_track(self, key, traffic_type, event_type, value=None, properties=None): + """ + Validate track call parameters + + :param key: user key associated to the event + :type key: str + :param traffic_type: traffic type name + :type traffic_type: str + :param event_type: event type name + :type event_type: str + :param value: (Optional) value associated to the event + :type value: Number + :param properties: (Optional) properties associated to the event + :type properties: dict + + :return: validation, event created and its properties size. + :rtype: tuple(bool, splitio.models.events.Event, int) + """ if self.destroyed: _LOGGER.error("Client has already been destroyed - no calls possible") - return input_validator.generate_control_treatments(features, method_name) + return False, None, None + if self._factory._waiting_fork(): _LOGGER.error("Client is not ready - no calls possible") - return input_validator.generate_control_treatments(features, method_name) + return False, None, None - start = int(round(time.time() * 1000)) - - matching_key, bucketing_key = input_validator.validate_key(key, method_name) - if matching_key is None and bucketing_key is None: - return input_validator.generate_control_treatments(features, method_name) + key = input_validator.validate_track_key(key) + event_type = input_validator.validate_event_type(event_type) + value = input_validator.validate_value(value) + valid, properties, size = input_validator.valid_properties(properties, 'track') - if input_validator.validate_attributes(attributes, method_name) is False: - return input_validator.generate_control_treatments(features, method_name) + if key is None or event_type is None or traffic_type is None or value is False \ + or valid is False: + return False, None, None - features, missing = input_validator.validate_features_get_treatments( - method_name, - features, - self.ready, - self._factory._get_storage('splits') # pylint: disable=protected-access + event = Event( + key=key, + traffic_type_name=traffic_type, + event_type_id=event_type, + value=value, + timestamp=utctime_ms(), + properties=properties, ) - if features is None: + + return True, event, size + + def _get_properties(self, evaluation_options): + return evaluation_options.properties if evaluation_options != None else None + + def _get_fallback_treatment_with_config(self, feature): + fallback_treatment = self._fallback_treatment_calculator.resolve(feature, "") + return fallback_treatment.treatment, fallback_treatment.config + + def _get_fallback_eval_results(self, eval_result, feature): + result = copy.deepcopy(eval_result) + fallback_treatment = self._fallback_treatment_calculator.resolve(feature, result["impression"]["label"]) + result["impression"]["label"] = fallback_treatment.label + result["treatment"] = fallback_treatment.treatment + result["configurations"] = fallback_treatment.config + + return result + + def _check_impression_label(self, result): + return result['impression']['label'] == None or (result['impression']['label'] != None and result['impression']['label'].find(Label.SPLIT_NOT_FOUND) == -1) + + def _validate_sdk_event_info(self, sdk_event, callback_handle): + if not self._check_sdk_event(sdk_event): + return False + + if not hasattr(callback_handle, '__call__'): + _LOGGER.warning("Client Event Subscription: The callback handle passed must be of type function, ignoring event subscribing action.") + return False + + return True + + def _check_sdk_event(self, sdk_event): + if not isinstance(sdk_event, SdkEvent): + _LOGGER.warning("Client Event Subscription: The event passed must be of type SdkEvent, ignoring event subscribing action.") + return False + + return True + +class Client(ClientBase): # pylint: disable=too-many-instance-attributes + """Entry point for the split sdk.""" + + def __init__(self, factory, recorder, events_manager, labels_enabled=True, fallback_treatment_calculator=None): + """ + Construct a Client instance. + + :param factory: Split factory (client & manager container) + :type factory: splitio.client.factory.SplitFactory + + :param labels_enabled: Whether to store labels on impressions + :type labels_enabled: bool + + :param recorder: recorder instance + :type recorder: splitio.recorder.StatsRecorder + + :rtype: Client + """ + ClientBase.__init__(self, factory, recorder, events_manager, labels_enabled, fallback_treatment_calculator) + self._context_factory = EvaluationDataFactory(factory._get_storage('splits'), factory._get_storage('segments'), factory._get_storage('rule_based_segments')) + + def destroy(self): + """ + Destroy the underlying factory. + + Only applicable when using in-memory operation mode. + """ + self._factory.destroy() + + def on(self, sdk_event, callback_handle): + if not self._validate_sdk_event_info(sdk_event, callback_handle): + return + + self._events_manager.register(sdk_event, callback_handle) + + def get_treatment(self, key, feature_flag_name, attributes=None, evaluation_options=None): + """ + Get the treatment for a feature flag and key, with an optional dictionary of attributes. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param feature_flag_name: The name of the feature flag for which to get the treatment + :type feature_flag_name: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: The treatment for the key and feature flag + :rtype: str + """ + try: + treatment, _ = self._get_treatment(MethodExceptionsAndLatencies.TREATMENT, key, feature_flag_name, attributes, evaluation_options) + return treatment + + except: + _LOGGER.error('get_treatment failed') + treatment, _ = self._get_fallback_treatment_with_config(feature_flag_name) + return treatment + + def get_treatment_with_config(self, key, feature_flag_name, attributes=None, evaluation_options=None): + """ + Get the treatment and config for a feature flag and key, with optional dictionary of attributes. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param feature: The name of the feature flag for which to get the treatment + :type feature: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: The treatment for the key and feature flag + :rtype: tuple(str, str) + """ + try: + return self._get_treatment(MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG, key, feature_flag_name, attributes, evaluation_options) + + except Exception: + _LOGGER.error('get_treatment_with_config failed') + return self._get_fallback_treatment_with_config(feature_flag_name) + + def _get_treatment(self, method, key, feature, attributes=None, evaluation_options=None): + """ + Validate key, feature flag name and object, and get the treatment and config with an optional dictionary of attributes. + + :param key: The key for which to get the treatment + :type key: str + :param feature_flag_name: The name of the feature flag for which to get the treatment + :type feature_flag_name: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param method: The method calling this function + :type method: splitio.models.telemetry.MethodExceptionsAndLatencies + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: The treatment and config for the key and feature flag + :rtype: dict + """ + if not self._client_is_usable(): # not destroyed & not waiting for a fork + return self._get_fallback_treatment_with_config(feature) + + start = get_current_epoch_time_ms() + if not self.ready: + _LOGGER.error("Client is not ready - no calls possible") + self._telemetry_init_producer.record_not_ready_usage() + + try: + key, bucketing, feature, attributes, evaluation_options = self._validate_treatment_input(key, feature, attributes, method, evaluation_options) + except _InvalidInputError: + return self._get_fallback_treatment_with_config(feature) + + result = self._get_fallback_eval_results(self._NON_READY_EVAL_RESULT, feature) + + if self.ready: + try: + ctx = self._context_factory.context_for(key, [feature]) + input_validator.validate_feature_flag_names({feature: ctx.flags.get(feature)}, 'get_' + method.value) + result = self._evaluator.eval_with_context(key, bucketing, feature, attributes, ctx) + except RuntimeError as e: + _LOGGER.error('Error getting treatment for feature flag') + _LOGGER.debug('Error: ', exc_info=True) + self._telemetry_evaluation_producer.record_exception(method) + result = self._get_fallback_eval_results(self._FAILED_EVAL_RESULT, feature) + + properties = self._get_properties(evaluation_options) + if self._check_impression_label(result): + impression_decorated = self._build_impression(key, bucketing, feature, result, properties) + self._record_stats([(impression_decorated, attributes)], start, method) + + return result['treatment'], result['configurations'] + + def get_treatments(self, key, feature_flag_names, attributes=None, evaluation_options=None): + """ + Evaluate multiple feature flags and return a dictionary with all the feature flag/treatments. + + Get the treatments for a list of feature flags considering a key, with an optional dictionary of + attributes. This method never raises an exception. If there's a problem, the appropriate + log message will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param features: Array of the names of the feature flags for which to get the treatment + :type feature: list + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + try: + with_config = self._get_treatments(key, feature_flag_names, MethodExceptionsAndLatencies.TREATMENTS, attributes, evaluation_options) + return {feature_flag: result[0] for (feature_flag, result) in with_config.items()} + + except Exception: + return {feature: self._get_fallback_treatment_with_config(feature)[0] for feature in feature_flag_names} + + def get_treatments_with_config(self, key, feature_flag_names, attributes=None, evaluation_options=None): + """ + Evaluate multiple feature flags and return a dict with feature flag -> (treatment, config). + + Get the treatments for a list of feature flags considering a key, with an optional dictionary of + attributes. This method never raises an exception. If there's a problem, the appropriate + log message will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param features: Array of the names of the feature flags for which to get the treatment + :type feature: list + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + try: + return self._get_treatments(key, feature_flag_names, MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG, attributes, evaluation_options) + + except Exception: + return {feature: (self._get_fallback_treatment_with_config(feature)) for feature in feature_flag_names} + + def get_treatments_by_flag_set(self, key, flag_set, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, [flag_set], MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, attributes, evaluation_options) + + def get_treatments_by_flag_sets(self, key, flag_sets, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag sets. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_sets: list of flag sets + :type flag_sets: list + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, flag_sets, MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, attributes, evaluation_options) + + def get_treatments_with_config_by_flag_set(self, key, flag_set, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, [flag_set], MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, attributes, evaluation_options) + + def get_treatments_with_config_by_flag_sets(self, key, flag_sets, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, flag_sets, MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, attributes, evaluation_options) + + def _get_treatments_by_flag_sets(self, key, flag_sets, method, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag sets. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_sets: list of flag sets + :type flag_sets: list + :param method: Treatment by flag set method flavor + :type method: splitio.models.telemetry.MethodExceptionsAndLatencies + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + feature_flags_names = self._get_feature_flag_names_by_flag_sets(flag_sets, 'get_' + method.value) + if feature_flags_names == []: + _LOGGER.warning("%s: No valid Flag set or no feature flags found for evaluating treatments", 'get_' + method.value) return {} - bulk_impressions = [] - treatments = {name: (CONTROL, None) for name in missing} + if 'config' in method.value: + return self._get_treatments(key, feature_flags_names, method, attributes, evaluation_options) + + with_config = self._get_treatments(key, feature_flags_names, method, attributes, evaluation_options) + return {feature_flag: result[0] for (feature_flag, result) in with_config.items()} + + def get_treatments_by_flag_set(self, key, flag_set, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, [flag_set], MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, attributes, evaluation_options) + + def get_treatments_by_flag_sets(self, key, flag_sets, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag sets. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param flag_sets: list of flag sets + :type flag_sets: list + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, flag_sets, MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, attributes, evaluation_options) + + def get_treatments_with_config_by_flag_set(self, key, flag_set, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, [flag_set], MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, attributes, evaluation_options) + + def get_treatments_with_config_by_flag_sets(self, key, flag_sets, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return self._get_treatments_by_flag_sets( key, flag_sets, MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, attributes, evaluation_options) + + def _get_feature_flag_names_by_flag_sets(self, flag_sets, method_name): + """ + Sanitize given flag sets and return list of feature flag names associated with them + :param flag_sets: list of flag sets + :type flag_sets: list + + :return: list of feature flag names + :rtype: list + """ + sanitized_flag_sets = input_validator.validate_flag_sets(flag_sets, method_name) + feature_flags_by_set = self._feature_flag_storage.get_feature_flags_by_sets(sanitized_flag_sets) + if feature_flags_by_set is None: + _LOGGER.warning("Fetching feature flags for flag set %s encountered an error, skipping this flag set." % (flag_sets)) + return [] + + return feature_flags_by_set + + def _get_treatments(self, key, features, method, attributes=None, evaluation_options=None): + """ + Validate key, feature flag names and objects, and get the treatments and configs with an optional dictionary of attributes. + + :param key: The key for which to get the treatment + :type key: str + :param feature_flag_names: Array of feature flag names for which to get the treatments + :type feature_flag_names: list(str) + :param method: The method calling this function + :type method: splitio.models.telemetry.MethodExceptionsAndLatencies + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + + :return: The treatments and configs for the key and feature flags + :rtype: dict + """ + start = get_current_epoch_time_ms() + if not self._client_is_usable(): + return input_validator.generate_control_treatments(features, self._fallback_treatment_calculator) + + if not self.ready: + _LOGGER.error("Client is not ready - no calls possible") + self._telemetry_init_producer.record_not_ready_usage() try: - evaluations = self._evaluate_features_if_ready(matching_key, bucketing_key, - list(features), attributes) - - for feature in features: - try: - result = evaluations[feature] - impression = self._build_impression(matching_key, - feature, - result['treatment'], - result['impression']['label'], - result['impression']['change_number'], - bucketing_key, - utctime_ms()) - - bulk_impressions.append(impression) - treatments[feature] = (result['treatment'], result['configurations']) - - except Exception: # pylint: disable=broad-except - _LOGGER.error('%s: An exception occured when evaluating ' - 'feature %s returning CONTROL.' % (method_name, feature)) - treatments[feature] = CONTROL, None - _LOGGER.debug('Error: ', exc_info=True) - continue - - # Register impressions + key, bucketing, features, attributes, evaluation_options = self._validate_treatments_input(key, features, attributes, method, evaluation_options) + except _InvalidInputError: + return input_validator.generate_control_treatments(features, self._fallback_treatment_calculator) + + results = {n: self._get_fallback_eval_results(self._NON_READY_EVAL_RESULT, n) for n in features} + if self.ready: try: - if bulk_impressions: - self._record_stats( - [(i, attributes) for i in bulk_impressions], - start, - metric_name - ) - except Exception: # pylint: disable=broad-except - _LOGGER.error('%s: An exception when trying to store ' - 'impressions.' % method_name) + ctx = self._context_factory.context_for(key, features) + input_validator.validate_feature_flag_names({feature: ctx.flags.get(feature) for feature in features}, 'get_' + method.value) + results = self._evaluator.eval_many_with_context(key, bucketing, features, attributes, ctx) + except RuntimeError as e: + _LOGGER.error('Error getting treatment for feature flag') _LOGGER.debug('Error: ', exc_info=True) + self._telemetry_evaluation_producer.record_exception(method) + results = {n: self._get_fallback_eval_results(self._FAILED_EVAL_RESULT, n) for n in features} - return treatments - except Exception: # pylint: disable=broad-except - _LOGGER.error('Error getting treatment for features') - _LOGGER.debug('Error: ', exc_info=True) - return input_validator.generate_control_treatments(list(features), method_name) + properties = self._get_properties(evaluation_options) + imp_decorated_attrs = [ + (i, attributes) for i in self._build_impressions(key, bucketing, results, properties) + if i.Impression.label == None or (i.Impression.label != None and i.Impression.label.find(Label.SPLIT_NOT_FOUND)) == -1 + ] + self._record_stats(imp_decorated_attrs, start, method) + + return { + feature: (results[feature]['treatment'], results[feature]['configurations']) + for feature in results + } + + def _record_stats(self, impressions_decorated, start, operation): + """ + Record impressions. + + :param impressions_decorated: Generated impressions + :type impressions_decorated: list[tuple[splitio.models.impression.ImpressionDecorated, dict]] + + :param start: timestamp when get_treatment or get_treatments was called + :type start: int + + :param operation: operation performed. + :type operation: str + """ + end = get_current_epoch_time_ms() + self._recorder.record_treatment_stats(impressions_decorated, get_latency_bucket_index(end - start), + operation, 'get_' + operation.value) + + def track(self, key, traffic_type, event_type, value=None, properties=None): + """ + Track an event. - def _evaluate_features_if_ready(self, matching_key, bucketing_key, features, attributes=None): + :param key: user key associated to the event + :type key: str + :param traffic_type: traffic type name + :type traffic_type: str + :param event_type: event type name + :type event_type: str + :param value: (Optional) value associated to the event + :type value: Number + :param properties: (Optional) properties associated to the event + :type properties: dict + + :return: Whether the event was created or not. + :rtype: bool + """ if not self.ready: - return { - feature: { - 'treatment': CONTROL, - 'configurations': None, - 'impression': {'label': Label.NOT_READY, 'change_number': None} - } - for feature in features - } - - return self._evaluator.evaluate_features( - features, - matching_key, - bucketing_key, - attributes + _LOGGER.warning("track: the SDK is not ready, results may be incorrect. Make sure to wait for SDK readiness before using this method") + self._telemetry_init_producer.record_not_ready_usage() + + start = get_current_epoch_time_ms() + should_validate_existance = self.ready and self._factory._sdk_key != 'localhost' # pylint: disable=protected-access + traffic_type = input_validator.validate_traffic_type( + traffic_type, + should_validate_existance, + self._factory._get_storage('splits'), # pylint: disable=protected-access ) + is_valid, event, size = self._validate_track(key, traffic_type, event_type, value, properties) + if not is_valid: + return False - def get_treatment_with_config(self, key, feature, attributes=None): + try: + return_flag = self._recorder.record_track_stats([EventWrapper( + event=event, + size=size, + )], get_latency_bucket_index(get_current_epoch_time_ms() - start)) + return return_flag + + except Exception: # pylint: disable=broad-except + self._telemetry_evaluation_producer.record_exception(MethodExceptionsAndLatencies.TRACK) + _LOGGER.error('Error processing track event') + _LOGGER.debug('Error: ', exc_info=True) + return False + + +class ClientAsync(ClientBase): # pylint: disable=too-many-instance-attributes + """Entry point for the split sdk.""" + + def __init__(self, factory, recorder, events_manager, labels_enabled=True, fallback_treatment_calculator=None): """ - Get the treatment and config for a feature and key, with optional dictionary of attributes. + Construct a Client instance. + + :param factory: Split factory (client & manager container) + :type factory: splitio.client.factory.SplitFactory + + :param labels_enabled: Whether to store labels on impressions + :type labels_enabled: bool + + :param recorder: recorder instance + :type recorder: splitio.recorder.StatsRecorder + + :rtype: Client + """ + ClientBase.__init__(self, factory, recorder, events_manager, labels_enabled, fallback_treatment_calculator) + self._context_factory = AsyncEvaluationDataFactory(factory._get_storage('splits'), factory._get_storage('segments'), factory._get_storage('rule_based_segments')) + + async def destroy(self): + """ + Destroy the underlying factory. + + Only applicable when using in-memory operation mode. + """ + await self._factory.destroy() + + async def on(self, sdk_event, callback_handle): + if not self._validate_sdk_event_info(sdk_event, callback_handle): + return + + await self._events_manager.register(sdk_event, callback_handle) + + async def get_treatment(self, key, feature_flag_name, attributes=None, evaluation_options=None): + """ + Get the treatment for a feature and key, with an optional dictionary of attributes, for async calls This method never raises an exception. If there's a problem, the appropriate log message will be generated and the method will return the CONTROL treatment. @@ -242,15 +795,23 @@ def get_treatment_with_config(self, key, feature, attributes=None): :type feature: str :param attributes: An optional dictionary of attributes :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict :return: The treatment for the key and feature - :rtype: tuple(str, str) + :rtype: str """ - return self._make_evaluation(key, feature, attributes, 'get_treatment_with_config', - self._METRIC_GET_TREATMENT_WITH_CONFIG) + try: + treatment, _ = await self._get_treatment(MethodExceptionsAndLatencies.TREATMENT, key, feature_flag_name, attributes, evaluation_options) + return treatment - def get_treatment(self, key, feature, attributes=None): + except: + _LOGGER.error('get_treatment failed') + treatment, _ = self._get_fallback_treatment_with_config(feature_flag_name) + return treatment + + async def get_treatment_with_config(self, key, feature_flag_name, attributes=None, evaluation_options=None): """ - Get the treatment for a feature and key, with an optional dictionary of attributes. + Get the treatment for a feature and key, with an optional dictionary of attributes, for async calls This method never raises an exception. If there's a problem, the appropriate log message will be generated and the method will return the CONTROL treatment. @@ -261,78 +822,292 @@ def get_treatment(self, key, feature, attributes=None): :type feature: str :param attributes: An optional dictionary of attributes :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict :return: The treatment for the key and feature :rtype: str """ - treatment, _ = self._make_evaluation(key, feature, attributes, 'get_treatment', - self._METRIC_GET_TREATMENT) - return treatment + try: + return await self._get_treatment(MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG, key, feature_flag_name, attributes, evaluation_options) + + except Exception: + _LOGGER.error('get_treatment_with_config failed') + return self._get_fallback_treatment_with_config(feature_flag_name) + + async def _get_treatment(self, method, key, feature, attributes=None, evaluation_options=None): + """ + Validate key, feature flag name and object, and get the treatment and config with an optional dictionary of attributes, for async calls + + :param key: The key for which to get the treatment + :type key: str + :param feature_flag_name: The name of the feature flag for which to get the treatment + :type feature_flag_name: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param method: The method calling this function + :type method: splitio.models.telemetry.MethodExceptionsAndLatencies + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: The treatment and config for the key and feature flag + :rtype: dict + """ + if not self._client_is_usable(): # not destroyed & not waiting for a fork + return self._get_fallback_treatment_with_config(feature) + + start = get_current_epoch_time_ms() + if not self.ready: + _LOGGER.error("Client is not ready - no calls possible") + await self._telemetry_init_producer.record_not_ready_usage() - def get_treatments_with_config(self, key, features, attributes=None): + try: + key, bucketing, feature, attributes, evaluation_options = self._validate_treatment_input(key, feature, attributes, method, evaluation_options) + except _InvalidInputError: + return self._get_fallback_treatment_with_config(feature) + + result = self._get_fallback_eval_results(self._NON_READY_EVAL_RESULT, feature) + if self.ready: + try: + ctx = await self._context_factory.context_for(key, [feature]) + input_validator.validate_feature_flag_names({feature: ctx.flags.get(feature)}, 'get_' + method.value) + result = self._evaluator.eval_with_context(key, bucketing, feature, attributes, ctx) + except Exception as e: # toto narrow this + _LOGGER.error('Error getting treatment for feature flag') + _LOGGER.debug('Error: ', exc_info=True) + await self._telemetry_evaluation_producer.record_exception(method) + result = self._get_fallback_eval_results(self._FAILED_EVAL_RESULT, feature) + + properties = self._get_properties(evaluation_options) + if self._check_impression_label(result): + impression_decorated = self._build_impression(key, bucketing, feature, result, properties) + await self._record_stats([(impression_decorated, attributes)], start, method) + return result['treatment'], result['configurations'] + + async def get_treatments(self, key, feature_flag_names, attributes=None, evaluation_options=None): """ - Evaluate multiple features and return a dict with feature -> (treatment, config). + Evaluate multiple feature flags and return a dictionary with all the feature flag/treatments, for async calls - Get the treatments for a list of features considering a key, with an optional dictionary of + Get the treatments for a list of feature flags considering a key, with an optional dictionary of attributes. This method never raises an exception. If there's a problem, the appropriate log message will be generated and the method will return the CONTROL treatment. :param key: The key for which to get the treatment :type key: str - :param features: Array of the names of the features for which to get the treatment + :param features: Array of the names of the feature flags for which to get the treatment :type feature: list :param attributes: An optional dictionary of attributes :type attributes: dict - :return: Dictionary with the result of all the features provided + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided :rtype: dict """ - return self._make_evaluations(key, features, attributes, 'get_treatments_with_config', - self._METRIC_GET_TREATMENTS_WITH_CONFIG) + try: + with_config = await self._get_treatments(key, feature_flag_names, MethodExceptionsAndLatencies.TREATMENTS, attributes, evaluation_options) + return {feature_flag: result[0] for (feature_flag, result) in with_config.items()} + + except Exception: + return {feature: self._get_fallback_treatment_with_config(feature)[0] for feature in feature_flag_names} - def get_treatments(self, key, features, attributes=None): + async def get_treatments_with_config(self, key, feature_flag_names, attributes=None, evaluation_options=None): """ - Evaluate multiple features and return a dictionary with all the feature/treatments. + Evaluate multiple feature flags and return a dict with feature flag -> (treatment, config), for async calls - Get the treatments for a list of features considering a key, with an optional dictionary of + Get the treatments for a list of feature flags considering a key, with an optional dictionary of attributes. This method never raises an exception. If there's a problem, the appropriate log message will be generated and the method will return the CONTROL treatment. :param key: The key for which to get the treatment :type key: str - :param features: Array of the names of the features for which to get the treatment + :param features: Array of the names of the feature flags for which to get the treatment :type feature: list :param attributes: An optional dictionary of attributes :type attributes: dict - :return: Dictionary with the result of all the features provided + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided :rtype: dict """ - with_config = self._make_evaluations(key, features, attributes, 'get_treatments', - self._METRIC_GET_TREATMENTS) - return {feature: result[0] for (feature, result) in with_config.items()} - - def _build_impression( # pylint: disable=too-many-arguments - self, - matching_key, - feature_name, - treatment, - label, - change_number, - bucketing_key, - imp_time - ): - """Build an impression.""" - if not self._labels_enabled: - label = None - - return Impression( - matching_key=matching_key, feature_name=feature_name, - treatment=treatment, label=label, change_number=change_number, - bucketing_key=bucketing_key, time=imp_time - ) + try: + return await self._get_treatments(key, feature_flag_names, MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG, attributes, evaluation_options) - def _record_stats(self, impressions, start, operation): + except Exception: + return {feature: (self._get_fallback_treatment_with_config(feature)) for feature in feature_flag_names} + + async def get_treatments_by_flag_set(self, key, flag_set, attributes=None, evaluation_options=None): """ - Record impressions. + Get treatments for feature flags that contain given flag set. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return await self._get_treatments_by_flag_sets( key, [flag_set], MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, attributes, evaluation_options) + + async def get_treatments_by_flag_sets(self, key, flag_sets, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag sets. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_sets: list of flag sets + :type flag_sets: list + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return await self._get_treatments_by_flag_sets( key, flag_sets, MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, attributes, evaluation_options) - :param impressions: Generated impressions - :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + async def get_treatments_with_config_by_flag_set(self, key, flag_set, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return await self._get_treatments_by_flag_sets( key, [flag_set], MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, attributes, evaluation_options) + + async def get_treatments_with_config_by_flag_sets(self, key, flag_sets, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag set. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_set: flag set + :type flag_sets: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + return await self._get_treatments_by_flag_sets( key, flag_sets, MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, attributes, evaluation_options) + + async def _get_treatments_by_flag_sets(self, key, flag_sets, method, attributes=None, evaluation_options=None): + """ + Get treatments for feature flags that contain given flag sets. + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param flag_sets: list of flag sets + :type flag_sets: list + :param method: Treatment by flag set method flavor + :type method: splitio.models.telemetry.MethodExceptionsAndLatencies + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: Dictionary with the result of all the feature flags provided + :rtype: dict + """ + feature_flags_names = await self._get_feature_flag_names_by_flag_sets(flag_sets, 'get_' + method.value) + if feature_flags_names == []: + _LOGGER.warning("%s: No valid Flag set or no feature flags found for evaluating treatments", 'get_' + method.value) + return {} + + if 'config' in method.value: + return await self._get_treatments(key, feature_flags_names, method, attributes, evaluation_options) + + with_config = await self._get_treatments(key, feature_flags_names, method, attributes, evaluation_options) + return {feature_flag: result[0] for (feature_flag, result) in with_config.items()} + + async def _get_feature_flag_names_by_flag_sets(self, flag_sets, method_name): + """ + Sanitize given flag sets and return list of feature flag names associated with them + :param flag_sets: list of flag sets + :type flag_sets: list + :return: list of feature flag names + :rtype: list + """ + sanitized_flag_sets = input_validator.validate_flag_sets(flag_sets, method_name) + feature_flags_by_set = await self._feature_flag_storage.get_feature_flags_by_sets(sanitized_flag_sets) + if feature_flags_by_set is None: + _LOGGER.warning("Fetching feature flags for flag set %s encountered an error, skipping this flag set." % (flag_sets)) + return [] + + return feature_flags_by_set + + async def _get_treatments(self, key, features, method, attributes=None, evaluation_options=None): + """ + Validate key, feature flag names and objects, and get the treatments and configs with an optional dictionary of attributes, for async calls + + :param key: The key for which to get the treatment + :type key: str + :param feature_flag_names: Array of feature flag names for which to get the treatments + :type feature_flag_names: list(str) + :param method: The method calling this function + :type method: splitio.models.telemetry.MethodExceptionsAndLatencies + :param attributes: An optional dictionary of attributes + :type attributes: dict + :param evaluation_options: An optional dictionary of options + :type evaluation_options: dict + :return: The treatments and configs for the key and feature flags + :rtype: dict + """ + start = get_current_epoch_time_ms() + if not self._client_is_usable(): + return input_validator.generate_control_treatments(features, self._fallback_treatment_calculator) + + if not self.ready: + _LOGGER.error("Client is not ready - no calls possible") + await self._telemetry_init_producer.record_not_ready_usage() + + try: + key, bucketing, features, attributes, evaluation_options = self._validate_treatments_input(key, features, attributes, method, evaluation_options) + except _InvalidInputError: + return input_validator.generate_control_treatments(features, self._fallback_treatment_calculator) + + results = {n: self._get_fallback_eval_results(self._NON_READY_EVAL_RESULT, n) for n in features} + if self.ready: + try: + ctx = await self._context_factory.context_for(key, features) + input_validator.validate_feature_flag_names({feature: ctx.flags.get(feature) for feature in features}, 'get_' + method.value) + results = self._evaluator.eval_many_with_context(key, bucketing, features, attributes, ctx) + except Exception as e: # toto narrow this + _LOGGER.error('Error getting treatment for feature flag') + _LOGGER.debug('Error: ', exc_info=True) + await self._telemetry_evaluation_producer.record_exception(method) + results = {n: self._get_fallback_eval_results(self._FAILED_EVAL_RESULT, n) for n in features} + + properties = self._get_properties(evaluation_options) + imp_decorated_attrs = [ + (i, attributes) for i in self._build_impressions(key, bucketing, results, properties) + if i.Impression.label == None or (i.Impression.label != None and i.Impression.label.find(Label.SPLIT_NOT_FOUND)) == -1 + ] + await self._record_stats(imp_decorated_attrs, start, method) + + return { + feature: (res['treatment'], res['configurations']) + for feature, res in results.items() + } + + async def _record_stats(self, impressions_decorated, start, operation): + """ + Record impressions for async calls + + :param impressions_decorated: Generated impressions decorated + :type impressions_decorated: list[tuple[splitio.models.impression.Impression, dict]] :param start: timestamp when get_treatment or get_treatments was called :type start: int @@ -340,13 +1115,13 @@ def _record_stats(self, impressions, start, operation): :param operation: operation performed. :type operation: str """ - end = int(round(time.time() * 1000)) - self._recorder.record_treatment_stats(impressions, get_latency_bucket_index(end - start), - operation) + end = get_current_epoch_time_ms() + await self._recorder.record_treatment_stats(impressions_decorated, get_latency_bucket_index(end - start), + operation, 'get_' + operation.value) - def track(self, key, traffic_type, event_type, value=None, properties=None): + async def track(self, key, traffic_type, event_type, value=None, properties=None): """ - Track an event. + Track an event for async calls :param key: user key associated to the event :type key: str @@ -362,38 +1137,33 @@ def track(self, key, traffic_type, event_type, value=None, properties=None): :return: Whether the event was created or not. :rtype: bool """ - if self.destroyed: - _LOGGER.error("Client has already been destroyed - no calls possible") - return False - if self._factory._waiting_fork(): - _LOGGER.error("Client is not ready - no calls possible") - return False + if not self.ready: + _LOGGER.warning("track: the SDK is not ready, results may be incorrect. Make sure to wait for SDK readiness before using this method") + await self._telemetry_init_producer.record_not_ready_usage() - key = input_validator.validate_track_key(key) - event_type = input_validator.validate_event_type(event_type) - should_validate_existance = self.ready and self._factory._apikey != 'localhost' # pylint: disable=protected-access - traffic_type = input_validator.validate_traffic_type( + start = get_current_epoch_time_ms() + should_validate_existance = self.ready and self._factory._sdk_key != 'localhost' # pylint: disable=protected-access + traffic_type = await input_validator.validate_traffic_type_async( traffic_type, should_validate_existance, self._factory._get_storage('splits'), # pylint: disable=protected-access ) + is_valid, event, size = self._validate_track(key, traffic_type, event_type, value, properties) + if not is_valid: + return False - value = input_validator.validate_value(value) - valid, properties, size = input_validator.valid_properties(properties) + try: + return_flag = await self._recorder.record_track_stats([EventWrapper( + event=event, + size=size, + )], get_latency_bucket_index(get_current_epoch_time_ms() - start)) + return return_flag - if key is None or event_type is None or traffic_type is None or value is False \ - or valid is False: + except Exception: # pylint: disable=broad-except + await self._telemetry_evaluation_producer.record_exception(MethodExceptionsAndLatencies.TRACK) + _LOGGER.error('Error processing track event') + _LOGGER.debug('Error: ', exc_info=True) return False - event = Event( - key=key, - traffic_type_name=traffic_type, - event_type_id=event_type, - value=value, - timestamp=utctime_ms(), - properties=properties, - ) - return self._recorder.record_track_stats([EventWrapper( - event=event, - size=size, - )]) +class _InvalidInputError(Exception): + pass diff --git a/splitio/client/config.py b/splitio/client/config.py index 6b40a2c7..25b1bc31 100644 --- a/splitio/client/config.py +++ b/splitio/client/config.py @@ -1,21 +1,28 @@ """Default settings for the Split.IO SDK Python client.""" import os.path import logging +from enum import Enum from splitio.engine.impressions import ImpressionsMode - +from splitio.client.input_validator import validate_flag_sets, validate_fallback_treatment, validate_regex_name +from splitio.models.fallback_config import FallbackTreatmentsConfiguration _LOGGER = logging.getLogger(__name__) DEFAULT_DATA_SAMPLING = 1 +class AuthenticateScheme(Enum): + """Authentication Scheme.""" + NONE = 'NONE' + KERBEROS_SPNEGO = 'KERBEROS_SPNEGO' + KERBEROS_PROXY = 'KERBEROS_PROXY' DEFAULT_CONFIG = { - 'operationMode': 'in-memory', + 'operationMode': 'standalone', 'connectionTimeout': 1500, 'streamingEnabled': True, 'featuresRefreshRate': 30, 'segmentsRefreshRate': 30, - 'metricsRefreshRate': 60, + 'metricsRefreshRate': 3600, 'impressionsRefreshRate': 5 * 60, 'impressionsBulkSize': 5000, 'impressionsQueueSize': 10000, @@ -31,6 +38,7 @@ 'redisHost': 'localhost', 'redisPort': 6379, 'redisDb': 0, + 'redisUsername': None, 'redisPassword': None, 'redisSocketTimeout': None, 'redisSocketConnectTimeout': None, @@ -52,31 +60,51 @@ 'machineName': None, 'machineIp': None, 'splitFile': os.path.join(os.path.expanduser('~'), '.split'), + 'segmentDirectory': os.path.expanduser('~'), + 'localhostRefreshEnabled': False, 'preforkedInitialization': False, 'dataSampling': DEFAULT_DATA_SAMPLING, + 'storageWrapper': None, + 'storagePrefix': None, + 'storageType': None, + 'flagSetsFilter': None, + 'httpAuthenticateScheme': AuthenticateScheme.NONE, + 'kerberosPrincipalUser': None, + 'kerberosPrincipalPassword': None, + 'fallbackTreatments': FallbackTreatmentsConfiguration(None) } - -def _parse_operation_mode(apikey, config): +def _parse_operation_mode(sdk_key, config): """ - Process incoming config to determine operation mode. + Process incoming config to determine operation mode and storage type :param config: user supplied config :type config: dict - :returns: operation mode - :rtype: str + :returns: operation mode and storage type + :rtype: Tuple (str, str) """ - if apikey == 'localhost': - return 'localhost-standalone' + if sdk_key == 'localhost': + _LOGGER.debug('Using Localhost operation mode') + return 'localhost', 'localhost' if 'redisHost' in config or 'redisSentinels' in config: - return 'redis-consumer' + _LOGGER.debug('Using Redis storage operation mode') + return 'consumer', 'redis' + + if config.get('storageType') is not None: + if config.get('storageType').lower() == 'pluggable': + _LOGGER.debug('Using Pluggable storage operation mode') + return 'consumer', 'pluggable' - return 'inmemory-standalone' + _LOGGER.warning('You passed an invalid storageType, acceptable value is ' + '`pluggable`. Defaulting storage to In-Memory mode.') + _LOGGER.debug('Using In-Memory operation mode') + return 'standalone', 'memory' -def _sanitize_impressions_mode(mode, refresh_rate=None): + +def _sanitize_impressions_mode(storage_type, mode, refresh_rate=None): """ Check supplied impressions mode and adjust refresh rate. @@ -90,10 +118,10 @@ def _sanitize_impressions_mode(mode, refresh_rate=None): try: mode = ImpressionsMode(mode.upper()) except (ValueError, AttributeError): - _LOGGER.warning('You passed an invalid impressionsMode, impressionsMode should be ' - 'one of the following values: `debug` or `optimized`. ' - 'Defaulting to `optimized` mode.') mode = ImpressionsMode.OPTIMIZED + _LOGGER.warning('You passed an invalid impressionsMode, impressionsMode should be ' \ + 'one of the following values: `debug`, `none` or `optimized`. ' + ' Defaulting to `optimized` mode.') if mode == ImpressionsMode.DEBUG: refresh_rate = max(1, refresh_rate) if refresh_rate is not None else 60 @@ -102,13 +130,12 @@ def _sanitize_impressions_mode(mode, refresh_rate=None): return mode, refresh_rate - -def sanitize(apikey, config): +def sanitize(sdk_key, config): """ Look for inconsistencies or ill-formed configs and tune it accordingly. - :param apikey: customer's apikey - :type apikey: str + :param sdk_key: sdk key + :type sdk_key: str :param config: DEFAULT + user supplied config :type config: dict @@ -116,11 +143,65 @@ def sanitize(apikey, config): :returns: sanitized config :rtype: dict """ - config['operationMode'] = _parse_operation_mode(apikey, config) + config['operationMode'], config['storageType'] = _parse_operation_mode(sdk_key, config) processed = DEFAULT_CONFIG.copy() processed.update(config) - imp_mode, imp_rate = _sanitize_impressions_mode(config.get('impressionsMode'), + imp_mode, imp_rate = _sanitize_impressions_mode(config['storageType'], config.get('impressionsMode'), config.get('impressionsRefreshRate')) processed['impressionsMode'] = imp_mode processed['impressionsRefreshRate'] = imp_rate + if processed['metricsRefreshRate'] < 60: + _LOGGER.warning('metricRefreshRate parameter minimum value is 60 seconds, defaulting to 3600 seconds.') + processed['metricsRefreshRate'] = 3600 + + if config['operationMode'] == 'consumer' and config.get('flagSetsFilter') is not None: + processed['flagSetsFilter'] = None + _LOGGER.warning('config: FlagSets filter is not applicable for Consumer modes where the SDK does keep rollout data in sync. FlagSet filter was discarded.') + else: + processed['flagSetsFilter'] = sorted(validate_flag_sets(processed['flagSetsFilter'], 'SDK Config')) if processed['flagSetsFilter'] is not None else None + + if config.get('httpAuthenticateScheme') is not None: + try: + authenticate_scheme = AuthenticateScheme(config['httpAuthenticateScheme'].upper()) + except (ValueError, AttributeError): + authenticate_scheme = AuthenticateScheme.NONE + _LOGGER.warning('You passed an invalid HttpAuthenticationScheme, HttpAuthenticationScheme should be ' \ + 'one of the following values: `none`, `kerberos_proxy` or `kerberos_spnego`. ' + ' Defaulting to `none` mode.') + processed["httpAuthenticateScheme"] = authenticate_scheme + + processed = _sanitize_fallback_config(config, processed) + + if config.get("redisErrors") is not None: + _LOGGER.warning('Parameter `redisErrors` is deprecated as it is no longer supported in redis lib.' \ + ' Will ignore this value.') + + processed["redisErrors"] = None return processed + +def _sanitize_fallback_config(config, processed): + if config.get('fallbackTreatments') is None: + return processed + + if not isinstance(config['fallbackTreatments'], FallbackTreatmentsConfiguration): + _LOGGER.warning('Config: fallbackTreatments parameter should be of `FallbackTreatmentsConfiguration` class.') + processed['fallbackTreatments'] = None + return processed + + sanitized_global_fallback_treatment = config['fallbackTreatments'].global_fallback_treatment + if config['fallbackTreatments'].global_fallback_treatment is not None and not validate_fallback_treatment(config['fallbackTreatments'].global_fallback_treatment): + _LOGGER.warning('Config: global fallbacktreatment parameter is discarded.') + sanitized_global_fallback_treatment = None + + sanitized_flag_fallback_treatments = {} + if config['fallbackTreatments'].by_flag_fallback_treatment is not None: + for feature_name in config['fallbackTreatments'].by_flag_fallback_treatment.keys(): + if not validate_regex_name(feature_name) or not validate_fallback_treatment(config['fallbackTreatments'].by_flag_fallback_treatment[feature_name]): + _LOGGER.warning('Config: fallback treatment parameter for feature flag %s is discarded.', feature_name) + continue + + sanitized_flag_fallback_treatments[feature_name] = config['fallbackTreatments'].by_flag_fallback_treatment[feature_name] + + processed['fallbackTreatments'] = FallbackTreatmentsConfiguration(sanitized_global_fallback_treatment, sanitized_flag_fallback_treatments) + + return processed \ No newline at end of file diff --git a/splitio/client/factory.py b/splitio/client/factory.py index bc1827d9..10979b85 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -2,58 +2,96 @@ import logging import threading from collections import Counter - from enum import Enum +import queue -from splitio.client.client import Client +from splitio.optional.loaders import asyncio +from splitio.client.client import Client, ClientAsync from splitio.client import input_validator -from splitio.client.manager import SplitManager -from splitio.client.config import sanitize as sanitize_config, DEFAULT_DATA_SAMPLING +from splitio.client.config import sanitize as sanitize_config, DEFAULT_DATA_SAMPLING, AuthenticateScheme +from splitio.client.manager import SplitManager, SplitManagerAsync from splitio.client import util -from splitio.client.listener import ImpressionListenerWrapper -from splitio.engine.impressions import Manager as ImpressionsManager +from splitio.client.listener import ImpressionListenerWrapper, ImpressionListenerWrapperAsync +from splitio.engine.impressions.impressions import Manager as ImpressionsManager +from splitio.engine.impressions import set_classes, set_classes_async +from splitio.engine.impressions.strategies import StrategyDebugMode, StrategyNoneMode +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageConsumer, \ + TelemetryStorageProducerAsync, TelemetryStorageConsumerAsync +from splitio.engine.impressions.manager import Counter as ImpressionsCounter +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.events.events_manager import EventsManager, EventsManagerAsync +from splitio.events.events_manager_config import EventsManagerConfig +from splitio.events.events_task import EventsTask, EventsTaskAsync +from splitio.events.events_delivery import EventsDelivery +from splitio.models.fallback_config import FallbackTreatmentCalculator +from splitio.models.notification import SdkInternalEventNotification +from splitio.models.events import SdkInternalEvent # Storage from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \ - InMemoryImpressionStorage, InMemoryEventStorage + InMemoryImpressionStorage, InMemoryEventStorage, InMemoryTelemetryStorage, LocalhostTelemetryStorage, \ + InMemorySplitStorageAsync, InMemorySegmentStorageAsync, InMemoryImpressionStorageAsync, \ + InMemoryEventStorageAsync, InMemoryTelemetryStorageAsync, LocalhostTelemetryStorageAsync, \ + InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync from splitio.storage.adapters import redis from splitio.storage.redis import RedisSplitStorage, RedisSegmentStorage, RedisImpressionsStorage, \ - RedisEventsStorage + RedisEventsStorage, RedisTelemetryStorage, RedisSplitStorageAsync, RedisEventsStorageAsync,\ + RedisSegmentStorageAsync, RedisImpressionsStorageAsync, RedisTelemetryStorageAsync, \ + RedisRuleBasedSegmentsStorage, RedisRuleBasedSegmentsStorageAsync +from splitio.storage.pluggable import PluggableEventsStorage, PluggableImpressionsStorage, PluggableSegmentStorage, \ + PluggableSplitStorage, PluggableTelemetryStorage, PluggableTelemetryStorageAsync, PluggableEventsStorageAsync, \ + PluggableImpressionsStorageAsync, PluggableSegmentStorageAsync, PluggableSplitStorageAsync, \ + PluggableRuleBasedSegmentsStorage, PluggableRuleBasedSegmentsStorageAsync # APIs -from splitio.api.client import HttpClient -from splitio.api.splits import SplitsAPI -from splitio.api.segments import SegmentsAPI -from splitio.api.impressions import ImpressionsAPI -from splitio.api.events import EventsAPI -from splitio.api.auth import AuthAPI +from splitio.api.client import HttpClient, HttpClientAsync, HttpClientKerberos +from splitio.api.splits import SplitsAPI, SplitsAPIAsync +from splitio.api.segments import SegmentsAPI, SegmentsAPIAsync +from splitio.api.impressions import ImpressionsAPI, ImpressionsAPIAsync +from splitio.api.events import EventsAPI, EventsAPIAsync +from splitio.api.auth import AuthAPI, AuthAPIAsync +from splitio.api.telemetry import TelemetryAPI, TelemetryAPIAsync +from splitio.util.time import get_current_epoch_time_ms # Tasks -from splitio.tasks.split_sync import SplitSynchronizationTask -from splitio.tasks.segment_sync import SegmentSynchronizationTask -from splitio.tasks.impressions_sync import ImpressionsSyncTask, ImpressionsCountSyncTask -from splitio.tasks.events_sync import EventsSyncTask +from splitio.tasks.split_sync import SplitSynchronizationTask, SplitSynchronizationTaskAsync +from splitio.tasks.segment_sync import SegmentSynchronizationTask, SegmentSynchronizationTaskAsync +from splitio.tasks.impressions_sync import ImpressionsSyncTask, ImpressionsCountSyncTask,\ + ImpressionsCountSyncTaskAsync, ImpressionsSyncTaskAsync +from splitio.tasks.events_sync import EventsSyncTask, EventsSyncTaskAsync +from splitio.tasks.telemetry_sync import TelemetrySyncTask, TelemetrySyncTaskAsync # Synchronizer from splitio.sync.synchronizer import SplitTasks, SplitSynchronizers, Synchronizer, \ - LocalhostSynchronizer -from splitio.sync.manager import Manager -from splitio.sync.split import SplitSynchronizer, LocalSplitSynchronizer -from splitio.sync.segment import SegmentSynchronizer -from splitio.sync.impression import ImpressionSynchronizer, ImpressionsCountSynchronizer -from splitio.sync.event import EventSynchronizer + LocalhostSynchronizer, RedisSynchronizer, PluggableSynchronizer,\ + SynchronizerAsync, RedisSynchronizerAsync, LocalhostSynchronizerAsync +from splitio.sync.manager import Manager, RedisManager, ManagerAsync, RedisManagerAsync +from splitio.sync.split import SplitSynchronizer, LocalSplitSynchronizer, LocalhostMode,\ + SplitSynchronizerAsync, LocalSplitSynchronizerAsync +from splitio.sync.segment import SegmentSynchronizer, LocalSegmentSynchronizer, SegmentSynchronizerAsync,\ + LocalSegmentSynchronizerAsync +from splitio.sync.impression import ImpressionSynchronizer, ImpressionsCountSynchronizer, \ + ImpressionsCountSynchronizerAsync, ImpressionSynchronizerAsync +from splitio.sync.event import EventSynchronizer, EventSynchronizerAsync +from splitio.sync.telemetry import TelemetrySynchronizer, InMemoryTelemetrySubmitter, \ + LocalhostTelemetrySubmitter, RedisTelemetrySubmitter, LocalhostTelemetrySubmitterAsync, \ + InMemoryTelemetrySubmitterAsync, TelemetrySynchronizerAsync, RedisTelemetrySubmitterAsync + # Recorder -from splitio.recorder.recorder import StandardRecorder, PipelinedRecorder +from splitio.recorder.recorder import StandardRecorder, PipelinedRecorder, StandardRecorderAsync, PipelinedRecorderAsync # Localhost stuff -from splitio.client.localhost import LocalhostEventsStorage, LocalhostImpressionsStorage +from splitio.client.localhost import LocalhostEventsStorage, LocalhostImpressionsStorage, \ + LocalhostImpressionsStorageAsync, LocalhostEventsStorageAsync _LOGGER = logging.getLogger(__name__) _INSTANTIATED_FACTORIES = Counter() _INSTANTIATED_FACTORIES_LOCK = threading.RLock() _MIN_DEFAULT_DATA_SAMPLING_ALLOWED = 0.1 # 10% +_MAX_RETRY_SYNC_ALL = 3 +_UNIQUE_KEYS_CACHE_SIZE = 30000 class Status(Enum): @@ -71,18 +109,79 @@ class TimeoutException(Exception): pass -class SplitFactory(object): # pylint: disable=too-many-instance-attributes +class SplitFactoryBase(object): # pylint: disable=too-many-instance-attributes + """Split Factory/Container class.""" + + def __init__(self, sdk_key, storages): + self._sdk_key = sdk_key + self._storages = storages + self._status = None + + def _get_storage(self, name): + """ + Return a reference to the specified storage. + + :param name: Name of the requested storage. + :type name: str + + :return: requested factory. + :rtype: object + """ + return self._storages[name] + + @property + def ready(self): + """ + Return whether the factory is ready. + + :return: True if the factory is ready. False otherwhise. + :rtype: bool + """ + return self._status == Status.READY + + def _update_instantiated_factories(self): + self._status = Status.DESTROYED + with _INSTANTIATED_FACTORIES_LOCK: + _INSTANTIATED_FACTORIES.subtract([self._sdk_key]) + + @property + def destroyed(self): + """ + Return whether the factory has been destroyed or not. + + :return: True if the factory has been destroyed. False otherwise. + :rtype: bool + """ + return self._status == Status.DESTROYED + + def _waiting_fork(self): + """ + Return whether the factory is waiting to be recreated by forking or not. + + :return: True if the factory is waiting to be recreated by forking. False otherwise. + :rtype: bool + """ + return self._status == Status.WAITING_FORK + + +class SplitFactory(SplitFactoryBase): # pylint: disable=too-many-instance-attributes """Split Factory/Container class.""" def __init__( # pylint: disable=too-many-arguments self, - apikey, + sdk_key, storages, labels_enabled, recorder, + internal_events_queue, + events_manager, sync_manager=None, sdk_ready_flag=None, + telemetry_producer=None, + telemetry_init_producer=None, + telemetry_submitter=None, preforked_initialization=False, + fallback_treatment_calculator=None ): """ Class constructor. @@ -102,13 +201,20 @@ def __init__( # pylint: disable=too-many-arguments :param preforked_initialization: Whether should be instantiated as preforked or not. :type preforked_initialization: bool """ - self._apikey = apikey - self._storages = storages + SplitFactoryBase.__init__(self, sdk_key, storages) self._labels_enabled = labels_enabled self._sync_manager = sync_manager - self._sdk_internal_ready_flag = sdk_ready_flag self._recorder = recorder self._preforked_initialization = preforked_initialization + self._telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + self._telemetry_init_producer = telemetry_init_producer + self._telemetry_submitter = telemetry_submitter + self._ready_time = get_current_epoch_time_ms() + _LOGGER.debug("Running in threading mode") + self._sdk_internal_ready_flag = sdk_ready_flag + self._fallback_treatment_calculator = fallback_treatment_calculator + self._internal_events_queue = internal_events_queue + self._events_manager = events_manager self._start_status_updater() def _start_status_updater(self): @@ -125,29 +231,26 @@ def _start_status_updater(self): self._status = Status.NOT_INITIALIZED # add a listener that updates the status to READY once the flag is set. ready_updater = threading.Thread(target=self._update_status_when_ready, - name='SDKReadyFlagUpdater') - ready_updater.setDaemon(True) + name='SDKReadyFlagUpdater', daemon=True) ready_updater.start() else: self._status = Status.READY - + self._internal_events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, None)) + def _update_status_when_ready(self): """Wait until the sdk is ready and update the status.""" self._sdk_internal_ready_flag.wait() self._status = Status.READY self._sdk_ready_flag.set() + self._internal_events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, None)) - def _get_storage(self, name): - """ - Return a reference to the specified storage. + self._telemetry_init_producer.record_ready_time(get_current_epoch_time_ms() - self._ready_time) + redundant_factory_count, active_factory_count = _get_active_and_redundant_count() + self._telemetry_init_producer.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) - :param name: Name of the requested storage. - :type name: str - - :return: requested factory. - :rtype: object - """ - return self._storages[name] + config_post_thread = threading.Thread(target=self._telemetry_submitter.synchronize_config(), name="PostConfigData") + config_post_thread.setDaemon(True) + config_post_thread.start() def client(self): """ @@ -156,7 +259,7 @@ def client(self): This client is only a set of references to structures hold by the factory. Creating one a fast operation and safe to be used anywhere. """ - return Client(self, self._recorder, self._labels_enabled) + return Client(self, self._recorder, self._events_manager, self._labels_enabled, self._fallback_treatment_calculator) def manager(self): """ @@ -180,18 +283,9 @@ def block_until_ready(self, timeout=None): ready = self._sdk_ready_flag.wait(timeout) if not ready: + self._telemetry_init_producer.record_bur_time_out() raise TimeoutException('SDK Initialization: time of %d exceeded' % timeout) - @property - def ready(self): - """ - Return whether the factory is ready. - - :return: True if the factory is ready. False otherwhise. - :rtype: bool - """ - return self._status == Status.READY - def destroy(self, destroyed_event=None): """ Destroy the factory and render clients unusable. @@ -207,6 +301,8 @@ def destroy(self, destroyed_event=None): return try: + _LOGGER.info('Factory destroy called, stopping tasks.') + self._events_manager.destroy() if self._sync_manager is not None: if destroyed_event is not None: @@ -214,36 +310,14 @@ def _wait_for_tasks_to_stop(): self._sync_manager.stop(True) destroyed_event.set() - wait_thread = threading.Thread(target=_wait_for_tasks_to_stop) - wait_thread.setDaemon(True) + wait_thread = threading.Thread(target=_wait_for_tasks_to_stop, daemon=True) wait_thread.start() else: self._sync_manager.stop(False) elif destroyed_event is not None: destroyed_event.set() finally: - self._status = Status.DESTROYED - with _INSTANTIATED_FACTORIES_LOCK: - _INSTANTIATED_FACTORIES.subtract([self._apikey]) - - @property - def destroyed(self): - """ - Return whether the factory has been destroyed or not. - - :return: True if the factory has been destroyed. False otherwise. - :rtype: bool - """ - return self._status == Status.DESTROYED - - def _waiting_fork(self): - """ - Return whether the factory is waiting to be recreated by forking or not. - - :return: True if the factory is waiting to be recreated by forking. False otherwise. - :rtype: bool - """ - return self._status == Status.WAITING_FORK + self._update_instantiated_factories() def resume(self): """ @@ -261,13 +335,158 @@ def resume(self): initialization_thread = threading.Thread( target=self._sync_manager.start, name="SDKInitializer", + daemon=True ) - initialization_thread.setDaemon(True) initialization_thread.start() self._preforked_initialization = False # reset for status updater self._start_status_updater() +class SplitFactoryAsync(SplitFactoryBase): # pylint: disable=too-many-instance-attributes + """Split Factory/Container async class.""" + + def __init__( # pylint: disable=too-many-arguments + self, + sdk_key, + storages, + labels_enabled, + recorder, + internal_events_queue, + events_manager, + sync_manager=None, + telemetry_producer=None, + telemetry_init_producer=None, + telemetry_submitter=None, + manager_start_task=None, + api_client=None, + fallback_treatment_calculator=None + ): + """ + Class constructor. + + :param storages: Dictionary of storages for all split models. + :type storages: dict + :param labels_enabled: Whether the impressions should store labels or not. + :type labels_enabled: bool + :param apis: Dictionary of apis client wrappers + :type apis: dict + :param sync_manager: Manager synchronization + :type sync_manager: splitio.sync.manager.Manager + :param sdk_ready_flag: Event to set when the sdk is ready. + :type sdk_ready_flag: threading.Event + :param recorder: StatsRecorder instance + :type recorder: StatsRecorder + :param preforked_initialization: Whether should be instantiated as preforked or not. + :type preforked_initialization: bool + """ + SplitFactoryBase.__init__(self, sdk_key, storages) + self._labels_enabled = labels_enabled + self._sync_manager = sync_manager + self._recorder = recorder + self._telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + self._telemetry_init_producer = telemetry_init_producer + self._telemetry_submitter = telemetry_submitter + self._ready_time = get_current_epoch_time_ms() + _LOGGER.debug("Running in asyncio mode") + self._internal_events_queue = internal_events_queue + self._events_manager = events_manager + self._manager_start_task = manager_start_task + self._status = Status.NOT_INITIALIZED + self._sdk_ready_flag = asyncio.Event() + self._ready_task = asyncio.get_running_loop().create_task(self._update_status_when_ready_async()) + self._api_client = api_client + self._fallback_treatment_calculator = fallback_treatment_calculator + + async def _update_status_when_ready_async(self): + """Wait until the sdk is ready and update the status for async mode.""" + if self._manager_start_task is not None: + await self._manager_start_task + self._manager_start_task = None + await self._telemetry_init_producer.record_ready_time(get_current_epoch_time_ms() - self._ready_time) + redundant_factory_count, active_factory_count = _get_active_and_redundant_count() + await self._telemetry_init_producer.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + try: + await self._telemetry_submitter.synchronize_config() + except Exception as e: + _LOGGER.error("Failed to post Telemetry config") + _LOGGER.debug(str(e)) + self._status = Status.READY + self._sdk_ready_flag.set() + await self._internal_events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, None)) + + def manager(self): + """ + Return a new manager. + + This manager is only a set of references to structures hold by the factory. + Creating one a fast operation and safe to be used anywhere. + """ + return SplitManagerAsync(self) + + async def block_until_ready(self, timeout=None): + """ + Blocks until the sdk is ready or the timeout specified by the user expires. + + When ready, the factory's status is updated accordingly. + + :param timeout: Number of seconds to wait (fractions allowed) + :type timeout: int + """ + try: + await asyncio.wait_for(asyncio.shield(self._sdk_ready_flag.wait()), timeout) + except asyncio.TimeoutError as e: + _LOGGER.error("Exception initializing SDK") + _LOGGER.debug(str(e)) + await self._telemetry_init_producer.record_bur_time_out() + raise TimeoutException('SDK Initialization: time of %d exceeded' % timeout) + + async def destroy(self, destroyed_event=None): + """ + Destroy the factory and render clients unusable. + + Destroy frees up storage taken but split data, flushes impressions & events, + and invalidates the clients, making them return control. + + :param destroyed_event: Event to signal when destroy process has finished. + :type destroyed_event: threading.Event + """ + if self.destroyed: + _LOGGER.info('Factory already destroyed.') + return + + try: + _LOGGER.info('Factory destroy called, stopping tasks.') + if self._manager_start_task is not None and not self._manager_start_task.done(): + self._manager_start_task.cancel() + + if self._sync_manager is not None: + await self._sync_manager.stop(True) + + if not self._ready_task.done(): + self._ready_task.cancel() + self._ready_task = None + + if isinstance(self._storages['splits'], RedisSplitStorageAsync): + await self._get_storage('splits').redis.close() + + if isinstance(self._sync_manager, ManagerAsync) and isinstance(self._telemetry_submitter, InMemoryTelemetrySubmitterAsync): + await self._api_client.close_session() + + except Exception as e: + _LOGGER.error('Exception destroying factory.') + _LOGGER.debug(str(e)) + finally: + self._update_instantiated_factories() + + def client(self): + """ + Return a new client. + + This client is only a set of references to structures hold by the factory. + Creating one a fast operation and safe to be used anywhere. + """ + return ClientAsync(self, self._recorder, self._events_manager, self._labels_enabled, self._fallback_treatment_calculator) + def _wrap_impression_listener(listener, metadata): """ Wrap the impression listener if any. @@ -279,53 +498,108 @@ def _wrap_impression_listener(listener, metadata): """ if listener is not None: return ImpressionListenerWrapper(listener, metadata) + return None +def _wrap_impression_listener_async(listener, metadata): + """ + Wrap the impression listener if any. + + :param listener: User supplied impression listener or None + :type listener: splitio.client.listener.ImpressionListener | None + :param metadata: SDK Metadata + :type metadata: splitio.client.util.SdkMetadata + """ + if listener is not None: + return ImpressionListenerWrapperAsync(listener, metadata) + + return None def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pylint:disable=too-many-arguments,too-many-locals - auth_api_base_url=None, streaming_api_base_url=None): + auth_api_base_url=None, streaming_api_base_url=None, telemetry_api_base_url=None, + total_flag_sets=0, invalid_flag_sets=0): """Build and return a split factory tailored to the supplied config.""" if not input_validator.validate_factory_instantiation(api_key): return None - http_client = HttpClient( - sdk_url=sdk_url, - events_url=events_url, - auth_url=auth_api_base_url, - timeout=cfg.get('connectionTimeout') - ) + extra_cfg = {} + extra_cfg['sdk_url'] = sdk_url + extra_cfg['events_url'] = events_url + extra_cfg['auth_url'] = auth_api_base_url + extra_cfg['streaming_url'] = streaming_api_base_url + extra_cfg['telemetry_url'] = telemetry_api_base_url + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_consumer = TelemetryStorageConsumer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() + + authentication_params = None + if cfg.get("httpAuthenticateScheme") in [AuthenticateScheme.KERBEROS_SPNEGO, AuthenticateScheme.KERBEROS_PROXY]: + authentication_params = [cfg.get("kerberosPrincipalUser"), + cfg.get("kerberosPrincipalPassword")] + http_client = HttpClientKerberos( + sdk_url=sdk_url, + events_url=events_url, + auth_url=auth_api_base_url, + telemetry_url=telemetry_api_base_url, + timeout=cfg.get('connectionTimeout'), + authentication_scheme = cfg.get("httpAuthenticateScheme"), + authentication_params = authentication_params + ) + else: + http_client = HttpClient( + sdk_url=sdk_url, + events_url=events_url, + auth_url=auth_api_base_url, + telemetry_url=telemetry_api_base_url, + timeout=cfg.get('connectionTimeout'), + ) sdk_metadata = util.get_metadata(cfg) apis = { - 'auth': AuthAPI(http_client, api_key, sdk_metadata), - 'splits': SplitsAPI(http_client, api_key, sdk_metadata), - 'segments': SegmentsAPI(http_client, api_key, sdk_metadata), - 'impressions': ImpressionsAPI(http_client, api_key, sdk_metadata, cfg['impressionsMode']), - 'events': EventsAPI(http_client, api_key, sdk_metadata), + 'auth': AuthAPI(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'splits': SplitsAPI(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'segments': SegmentsAPI(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'impressions': ImpressionsAPI(http_client, api_key, sdk_metadata, telemetry_runtime_producer, cfg['impressionsMode']), + 'events': EventsAPI(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'telemetry': TelemetryAPI(http_client, api_key, sdk_metadata, telemetry_runtime_producer), } - - if not input_validator.validate_apikey_type(apis['segments']): - return None - + + internal_events_queue = queue.Queue() + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, internal_events_queue) storages = { - 'splits': InMemorySplitStorage(), - 'segments': InMemorySegmentStorage(), - 'impressions': InMemoryImpressionStorage(cfg['impressionsQueueSize']), - 'events': InMemoryEventStorage(cfg['eventsQueueSize']), + 'splits': InMemorySplitStorage(internal_events_queue, cfg['flagSetsFilter'] if cfg['flagSetsFilter'] is not None else []), + 'segments': InMemorySegmentStorage(internal_events_queue), + 'rule_based_segments': InMemoryRuleBasedSegmentStorage(internal_events_queue), + 'impressions': InMemoryImpressionStorage(cfg['impressionsQueueSize'], telemetry_runtime_producer), + 'events': InMemoryEventStorage(cfg['eventsQueueSize'], telemetry_runtime_producer), } + telemetry_submitter = InMemoryTelemetrySubmitter(telemetry_consumer, storages['splits'], storages['segments'], apis['telemetry']) + + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTracker(_UNIQUE_KEYS_CACHE_SIZE) + unique_keys_synchronizer, clear_filter_sync, unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes('MEMORY', cfg['impressionsMode'], apis, imp_counter, unique_keys_tracker) + imp_manager = ImpressionsManager( - cfg['impressionsMode'], - True, - _wrap_impression_listener(cfg['impressionListener'], sdk_metadata)) + imp_strategy, none_strategy, telemetry_runtime_producer) synchronizers = SplitSynchronizers( - SplitSynchronizer(apis['splits'], storages['splits']), - SegmentSynchronizer(apis['segments'], storages['splits'], storages['segments']), + SplitSynchronizer(apis['splits'], storages['splits'], storages['rule_based_segments']), + SegmentSynchronizer(apis['segments'], storages['splits'], storages['segments'], storages['rule_based_segments']), ImpressionSynchronizer(apis['impressions'], storages['impressions'], cfg['impressionsBulkSize']), EventSynchronizer(apis['events'], storages['events'], cfg['eventsBulkSize']), - ImpressionsCountSynchronizer(apis['impressions'], imp_manager), + impressions_count_sync, + TelemetrySynchronizer(telemetry_submitter), + unique_keys_synchronizer, + clear_filter_sync, ) tasks = SplitTasks( @@ -342,7 +616,11 @@ def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pyl cfg['impressionsRefreshRate'], ), EventsSyncTask(synchronizers.events_sync.synchronize_events, cfg['eventsPushRate']), - ImpressionsCountSyncTask(synchronizers.impressions_count_sync.synchronize_counters) + impressions_count_task, + TelemetrySyncTask(synchronizers.telemetry_sync.synchronize_stats, cfg['metricsRefreshRate']), + unique_keys_task, + clear_filter_task, + internal_events_task ) synchronizer = Synchronizer(synchronizers, tasks) @@ -351,7 +629,7 @@ def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pyl sdk_ready_flag = threading.Event() if not preforked_initialization else None manager = Manager(sdk_ready_flag, synchronizer, apis['auth'], cfg['streamingEnabled'], - sdk_metadata, streaming_api_base_url, api_key[-4:]) + sdk_metadata, telemetry_runtime_producer, streaming_api_base_url, api_key[-4:]) storages['events'].set_queue_full_hook(tasks.events_task.flush) storages['impressions'].set_queue_full_hook(tasks.impressions_task.flush) @@ -360,21 +638,155 @@ def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pyl imp_manager, storages['events'], storages['impressions'], + telemetry_evaluation_producer, + telemetry_runtime_producer, + _wrap_impression_listener(cfg['impressionListener'], sdk_metadata), + imp_counter=imp_counter, + unique_keys_tracker=unique_keys_tracker ) + telemetry_init_producer.record_config(cfg, extra_cfg, total_flag_sets, invalid_flag_sets) + internal_events_task.start() + if preforked_initialization: - synchronizer.sync_all() + synchronizer.sync_all(max_retry_attempts=_MAX_RETRY_SYNC_ALL) synchronizer._split_synchronizers._segment_sync.shutdown() + return SplitFactory(api_key, storages, cfg['labelsEnabled'], - recorder, manager, preforked_initialization=preforked_initialization) + recorder, internal_events_queue, events_manager, manager, None, telemetry_producer, telemetry_init_producer, telemetry_submitter, preforked_initialization=preforked_initialization, + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments'])) - initialization_thread = threading.Thread(target=manager.start, name="SDKInitializer") - initialization_thread.setDaemon(True) + initialization_thread = threading.Thread(target=manager.start, name="SDKInitializer", daemon=True) initialization_thread.start() return SplitFactory(api_key, storages, cfg['labelsEnabled'], - recorder, manager, sdk_ready_flag) + recorder, internal_events_queue, events_manager, manager, sdk_ready_flag, + telemetry_producer, telemetry_init_producer, + telemetry_submitter, fallback_treatment_calculator = FallbackTreatmentCalculator(cfg['fallbackTreatments'])) + +async def _build_in_memory_factory_async(api_key, cfg, sdk_url=None, events_url=None, # pylint:disable=too-many-arguments,too-many-localsa + auth_api_base_url=None, streaming_api_base_url=None, telemetry_api_base_url=None, + total_flag_sets=0, invalid_flag_sets=0): + """Build and return a split factory tailored to the supplied config in async mode.""" + if not input_validator.validate_factory_instantiation(api_key): + return None + + extra_cfg = {} + extra_cfg['sdk_url'] = sdk_url + extra_cfg['events_url'] = events_url + extra_cfg['auth_url'] = auth_api_base_url + extra_cfg['streaming_url'] = streaming_api_base_url + extra_cfg['telemetry_url'] = telemetry_api_base_url + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_consumer = TelemetryStorageConsumerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() + + http_client = HttpClientAsync( + sdk_url=sdk_url, + events_url=events_url, + auth_url=auth_api_base_url, + telemetry_url=telemetry_api_base_url, + timeout=cfg.get('connectionTimeout') + ) + + sdk_metadata = util.get_metadata(cfg) + apis = { + 'auth': AuthAPIAsync(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'splits': SplitsAPIAsync(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'segments': SegmentsAPIAsync(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'impressions': ImpressionsAPIAsync(http_client, api_key, sdk_metadata, telemetry_runtime_producer, cfg['impressionsMode']), + 'events': EventsAPIAsync(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + 'telemetry': TelemetryAPIAsync(http_client, api_key, sdk_metadata, telemetry_runtime_producer), + } + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTaskAsync(events_manager.notify_internal_event, internal_events_queue) + + storages = { + 'splits': InMemorySplitStorageAsync(internal_events_queue, cfg['flagSetsFilter'] if cfg['flagSetsFilter'] is not None else []), + 'segments': InMemorySegmentStorageAsync(internal_events_queue), + 'rule_based_segments': InMemoryRuleBasedSegmentStorageAsync(internal_events_queue), + 'impressions': InMemoryImpressionStorageAsync(cfg['impressionsQueueSize'], telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(cfg['eventsQueueSize'], telemetry_runtime_producer), + } + + telemetry_submitter = InMemoryTelemetrySubmitterAsync(telemetry_consumer, storages['splits'], storages['segments'], apis['telemetry']) + + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTrackerAsync(_UNIQUE_KEYS_CACHE_SIZE) + unique_keys_synchronizer, clear_filter_sync, unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes_async('MEMORY', cfg['impressionsMode'], apis, imp_counter, unique_keys_tracker) + + imp_manager = ImpressionsManager( + imp_strategy, none_strategy, telemetry_runtime_producer) + + synchronizers = SplitSynchronizers( + SplitSynchronizerAsync(apis['splits'], storages['splits'], storages['rule_based_segments']), + SegmentSynchronizerAsync(apis['segments'], storages['splits'], storages['segments'], storages['rule_based_segments']), + ImpressionSynchronizerAsync(apis['impressions'], storages['impressions'], + cfg['impressionsBulkSize']), + EventSynchronizerAsync(apis['events'], storages['events'], cfg['eventsBulkSize']), + impressions_count_sync, + TelemetrySynchronizerAsync(telemetry_submitter), + unique_keys_synchronizer, + clear_filter_sync, + ) + + tasks = SplitTasks( + SplitSynchronizationTaskAsync( + synchronizers.split_sync.synchronize_splits, + cfg['featuresRefreshRate'], + ), + SegmentSynchronizationTaskAsync( + synchronizers.segment_sync.synchronize_segments, + cfg['segmentsRefreshRate'], + ), + ImpressionsSyncTaskAsync( + synchronizers.impressions_sync.synchronize_impressions, + cfg['impressionsRefreshRate'], + ), + EventsSyncTaskAsync(synchronizers.events_sync.synchronize_events, cfg['eventsPushRate']), + impressions_count_task, + TelemetrySyncTaskAsync(synchronizers.telemetry_sync.synchronize_stats, cfg['metricsRefreshRate']), + unique_keys_task, + clear_filter_task, + internal_events_task + ) + synchronizer = SynchronizerAsync(synchronizers, tasks) + + manager = ManagerAsync(synchronizer, apis['auth'], cfg['streamingEnabled'], + sdk_metadata, telemetry_runtime_producer, streaming_api_base_url, api_key[-4:]) + + storages['events'].set_queue_full_hook(tasks.events_task.flush) + storages['impressions'].set_queue_full_hook(tasks.impressions_task.flush) + + recorder = StandardRecorderAsync( + imp_manager, + storages['events'], + storages['impressions'], + telemetry_evaluation_producer, + telemetry_runtime_producer, + _wrap_impression_listener_async(cfg['impressionListener'], sdk_metadata), + imp_counter=imp_counter, + unique_keys_tracker=unique_keys_tracker + ) + + await telemetry_init_producer.record_config(cfg, extra_cfg, total_flag_sets, invalid_flag_sets) + internal_events_task.start() + + manager_start_task = asyncio.get_running_loop().create_task(manager.start()) + + return SplitFactoryAsync(api_key, storages, cfg['labelsEnabled'], + recorder, internal_events_queue, events_manager, manager, + telemetry_producer, telemetry_init_producer, + telemetry_submitter, manager_start_task=manager_start_task, + api_client=http_client, fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments'])) def _build_redis_factory(api_key, cfg): """Build and return a split factory with redis-based storage.""" @@ -383,110 +795,620 @@ def _build_redis_factory(api_key, cfg): cache_enabled = cfg.get('redisLocalCacheEnabled', False) cache_ttl = cfg.get('redisLocalCacheTTL', 5) storages = { - 'splits': RedisSplitStorage(redis_adapter, cache_enabled, cache_ttl), + 'splits': RedisSplitStorage(redis_adapter, cache_enabled, cache_ttl, []), 'segments': RedisSegmentStorage(redis_adapter), + 'rule_based_segments': RedisRuleBasedSegmentsStorage(redis_adapter), 'impressions': RedisImpressionsStorage(redis_adapter, sdk_metadata), 'events': RedisEventsStorage(redis_adapter, sdk_metadata), + 'telemetry': RedisTelemetryStorage(redis_adapter, sdk_metadata) } + telemetry_producer = TelemetryStorageProducer(storages['telemetry']) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() + telemetry_submitter = RedisTelemetrySubmitter(storages['telemetry']) + data_sampling = cfg.get('dataSampling', DEFAULT_DATA_SAMPLING) if data_sampling < _MIN_DEFAULT_DATA_SAMPLING_ALLOWED: _LOGGER.warning("dataSampling cannot be less than %.2f, defaulting to minimum", _MIN_DEFAULT_DATA_SAMPLING_ALLOWED) data_sampling = _MIN_DEFAULT_DATA_SAMPLING_ALLOWED + + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTracker(_UNIQUE_KEYS_CACHE_SIZE) + unique_keys_synchronizer, clear_filter_sync, unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes('REDIS', cfg['impressionsMode'], redis_adapter, imp_counter, unique_keys_tracker) + + imp_manager = ImpressionsManager( + imp_strategy, none_strategy, + telemetry_runtime_producer) + + synchronizers = SplitSynchronizers(None, None, None, None, + impressions_count_sync, + None, + unique_keys_synchronizer, + clear_filter_sync + ) + + tasks = SplitTasks(None, None, None, None, + impressions_count_task, + None, + unique_keys_task, + clear_filter_task + ) + + synchronizer = RedisSynchronizer(synchronizers, tasks) recorder = PipelinedRecorder( redis_adapter.pipeline, - ImpressionsManager(cfg['impressionsMode'], False, - _wrap_impression_listener(cfg['impressionListener'], sdk_metadata)), + imp_manager, storages['events'], storages['impressions'], + storages['telemetry'], data_sampling, + _wrap_impression_listener(cfg['impressionListener'], sdk_metadata), + imp_counter=imp_counter, + unique_keys_tracker=unique_keys_tracker ) - return SplitFactory( + + manager = RedisManager(synchronizer) + initialization_thread = threading.Thread(target=manager.start, name="SDKInitializer", daemon=True) + initialization_thread.start() + + telemetry_init_producer.record_config(cfg, {}, 0, 0) + internal_events_queue = queue.Queue() + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + + split_factory = SplitFactory( + api_key, + storages, + cfg['labelsEnabled'], + recorder, + internal_events_queue, + events_manager, + manager, + sdk_ready_flag=None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_init_producer, + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments']) + ) + redundant_factory_count, active_factory_count = _get_active_and_redundant_count() + storages['telemetry'].record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + telemetry_submitter.synchronize_config() + + return split_factory + +async def _build_redis_factory_async(api_key, cfg): + """Build and return a split factory with redis-based storage.""" + sdk_metadata = util.get_metadata(cfg) + redis_adapter = await redis.build_async(cfg) + cache_enabled = cfg.get('redisLocalCacheEnabled', False) + cache_ttl = cfg.get('redisLocalCacheTTL', 5) + storages = { + 'splits': RedisSplitStorageAsync(redis_adapter, cache_enabled, cache_ttl), + 'segments': RedisSegmentStorageAsync(redis_adapter), + 'rule_based_segments': RedisRuleBasedSegmentsStorageAsync(redis_adapter), + 'impressions': RedisImpressionsStorageAsync(redis_adapter, sdk_metadata), + 'events': RedisEventsStorageAsync(redis_adapter, sdk_metadata), + 'telemetry': await RedisTelemetryStorageAsync.create(redis_adapter, sdk_metadata) + } + telemetry_producer = TelemetryStorageProducerAsync(storages['telemetry']) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() + telemetry_submitter = RedisTelemetrySubmitterAsync(storages['telemetry']) + + data_sampling = cfg.get('dataSampling', DEFAULT_DATA_SAMPLING) + if data_sampling < _MIN_DEFAULT_DATA_SAMPLING_ALLOWED: + _LOGGER.warning("dataSampling cannot be less than %.2f, defaulting to minimum", + _MIN_DEFAULT_DATA_SAMPLING_ALLOWED) + data_sampling = _MIN_DEFAULT_DATA_SAMPLING_ALLOWED + + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTrackerAsync(_UNIQUE_KEYS_CACHE_SIZE) + unique_keys_synchronizer, clear_filter_sync, unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes_async('REDIS', cfg['impressionsMode'], redis_adapter, imp_counter, unique_keys_tracker) + + imp_manager = ImpressionsManager( + imp_strategy, none_strategy, + telemetry_runtime_producer) + + synchronizers = SplitSynchronizers(None, None, None, None, + impressions_count_sync, + None, + unique_keys_synchronizer, + clear_filter_sync + ) + + tasks = SplitTasks(None, None, None, None, + impressions_count_task, + None, + unique_keys_task, + clear_filter_task + ) + + synchronizer = RedisSynchronizerAsync(synchronizers, tasks) + recorder = PipelinedRecorderAsync( + redis_adapter.pipeline, + imp_manager, + storages['events'], + storages['impressions'], + storages['telemetry'], + data_sampling, + _wrap_impression_listener_async(cfg['impressionListener'], sdk_metadata), + imp_counter=imp_counter, + unique_keys_tracker=unique_keys_tracker + ) + + manager = RedisManagerAsync(synchronizer) + await telemetry_init_producer.record_config(cfg, {}, 0, 0) + manager.start() + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_factory = SplitFactoryAsync( + api_key, + storages, + cfg['labelsEnabled'], + recorder, + internal_events_queue, + events_manager, + manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_init_producer, + telemetry_submitter=telemetry_submitter, + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments']) + ) + redundant_factory_count, active_factory_count = _get_active_and_redundant_count() + await storages['telemetry'].record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + await telemetry_submitter.synchronize_config() + + return split_factory + +def _build_pluggable_factory(api_key, cfg): + """Build and return a split factory with pluggable storage.""" + sdk_metadata = util.get_metadata(cfg) + if not input_validator.validate_pluggable_adapter(cfg): + raise Exception("Pluggable Adapter validation failed, exiting") + + pluggable_adapter = cfg.get('storageWrapper') + storage_prefix = cfg.get('storagePrefix') + storages = { + 'splits': PluggableSplitStorage(pluggable_adapter, storage_prefix, []), + 'segments': PluggableSegmentStorage(pluggable_adapter, storage_prefix), + 'rule_based_segments': PluggableRuleBasedSegmentsStorage(pluggable_adapter, storage_prefix), + 'impressions': PluggableImpressionsStorage(pluggable_adapter, sdk_metadata, storage_prefix), + 'events': PluggableEventsStorage(pluggable_adapter, sdk_metadata, storage_prefix), + 'telemetry': PluggableTelemetryStorage(pluggable_adapter, sdk_metadata, storage_prefix) + } + telemetry_producer = TelemetryStorageProducer(storages['telemetry']) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() + # Using same class as redis + telemetry_submitter = RedisTelemetrySubmitter(storages['telemetry']) + + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTracker(_UNIQUE_KEYS_CACHE_SIZE) + unique_keys_synchronizer, clear_filter_sync, unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes('PLUGGABLE', cfg['impressionsMode'], pluggable_adapter, imp_counter, unique_keys_tracker, storage_prefix) + + imp_manager = ImpressionsManager( + imp_strategy, none_strategy, + telemetry_runtime_producer) + + synchronizers = SplitSynchronizers(None, None, None, None, + impressions_count_sync, + None, + unique_keys_synchronizer, + clear_filter_sync + ) + + tasks = SplitTasks(None, None, None, None, + impressions_count_task, + None, + unique_keys_task, + clear_filter_task + ) + + # Using same class as redis for consumer mode only + synchronizer = RedisSynchronizer(synchronizers, tasks) + recorder = StandardRecorder( + imp_manager, + storages['events'], + storages['impressions'], + telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_runtime_producer, + _wrap_impression_listener(cfg['impressionListener'], sdk_metadata), + imp_counter=imp_counter, + unique_keys_tracker=unique_keys_tracker + ) + + # Using same class as redis for consumer mode only + manager = RedisManager(synchronizer) + initialization_thread = threading.Thread(target=manager.start, name="SDKInitializer", daemon=True) + initialization_thread.start() + + telemetry_init_producer.record_config(cfg, {}, 0, 0) + internal_events_queue = queue.Queue() + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + + split_factory = SplitFactory( + api_key, + storages, + cfg['labelsEnabled'], + recorder, + internal_events_queue, + events_manager, + manager, + sdk_ready_flag=None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_init_producer, + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments']) + ) + redundant_factory_count, active_factory_count = _get_active_and_redundant_count() + storages['telemetry'].record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + telemetry_submitter.synchronize_config() + + return split_factory + +async def _build_pluggable_factory_async(api_key, cfg): + """Build and return a split factory with pluggable storage.""" + sdk_metadata = util.get_metadata(cfg) + if not input_validator.validate_pluggable_adapter(cfg): + raise Exception("Pluggable Adapter validation failed, exiting") + + pluggable_adapter = cfg.get('storageWrapper') + storage_prefix = cfg.get('storagePrefix') + storages = { + 'splits': PluggableSplitStorageAsync(pluggable_adapter, storage_prefix), + 'segments': PluggableSegmentStorageAsync(pluggable_adapter, storage_prefix), + 'rule_based_segments': PluggableRuleBasedSegmentsStorageAsync(pluggable_adapter, storage_prefix), + 'impressions': PluggableImpressionsStorageAsync(pluggable_adapter, sdk_metadata, storage_prefix), + 'events': PluggableEventsStorageAsync(pluggable_adapter, sdk_metadata, storage_prefix), + 'telemetry': await PluggableTelemetryStorageAsync.create(pluggable_adapter, sdk_metadata, storage_prefix) + } + telemetry_producer = TelemetryStorageProducerAsync(storages['telemetry']) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_init_producer = telemetry_producer.get_telemetry_init_producer() + # Using same class as redis + telemetry_submitter = RedisTelemetrySubmitterAsync(storages['telemetry']) + + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTrackerAsync(_UNIQUE_KEYS_CACHE_SIZE) + unique_keys_synchronizer, clear_filter_sync, unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes_async('PLUGGABLE', cfg['impressionsMode'], pluggable_adapter, imp_counter, unique_keys_tracker, storage_prefix) + + imp_manager = ImpressionsManager( + imp_strategy, none_strategy, + telemetry_runtime_producer) + + synchronizers = SplitSynchronizers(None, None, None, None, + impressions_count_sync, + None, + unique_keys_synchronizer, + clear_filter_sync + ) + + tasks = SplitTasks(None, None, None, None, + impressions_count_task, + None, + unique_keys_task, + clear_filter_task + ) + + # Using same class as redis for consumer mode only + synchronizer = RedisSynchronizerAsync(synchronizers, tasks) + recorder = StandardRecorderAsync( + imp_manager, + storages['events'], + storages['impressions'], + telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_runtime_producer, + _wrap_impression_listener_async(cfg['impressionListener'], sdk_metadata), + imp_counter=imp_counter, + unique_keys_tracker=unique_keys_tracker + ) + + # Using same class as redis for consumer mode only + manager = RedisManagerAsync(synchronizer) + manager.start() + await telemetry_init_producer.record_config(cfg, {}, 0, 0) + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_factory = SplitFactoryAsync( api_key, storages, cfg['labelsEnabled'], recorder, + internal_events_queue, + events_manager, + manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_init_producer, + telemetry_submitter=telemetry_submitter, + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments']) ) + redundant_factory_count, active_factory_count = _get_active_and_redundant_count() + await storages['telemetry'].record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + await telemetry_submitter.synchronize_config() + return split_factory def _build_localhost_factory(cfg): """Build and return a localhost factory for testing/development purposes.""" + telemetry_storage = LocalhostTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + internal_events_queue = queue.Queue() storages = { - 'splits': InMemorySplitStorage(), - 'segments': InMemorySegmentStorage(), # not used, just to avoid possible future errors. + 'splits': InMemorySplitStorage(internal_events_queue, cfg['flagSetsFilter'] if cfg['flagSetsFilter'] is not None else []), + 'segments': InMemorySegmentStorage(internal_events_queue), # not used, just to avoid possible future errors. + 'rule_based_segments': InMemoryRuleBasedSegmentStorage(internal_events_queue), 'impressions': LocalhostImpressionsStorage(), 'events': LocalhostEventsStorage(), } - + localhost_mode = LocalhostMode.JSON if cfg['splitFile'][-5:].lower() == '.json' else LocalhostMode.LEGACY synchronizers = SplitSynchronizers( - LocalSplitSynchronizer(cfg['splitFile'], storages['splits']), - None, None, None, None, + LocalSplitSynchronizer(cfg['splitFile'], + storages['splits'], + storages['rule_based_segments'], + localhost_mode), + LocalSegmentSynchronizer(cfg['segmentDirectory'], storages['splits'], storages['segments']), + None, None, None, ) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, internal_events_queue) - tasks = SplitTasks( - SplitSynchronizationTask( + feature_flag_sync_task = None + segment_sync_task = None + if cfg['localhostRefreshEnabled'] and localhost_mode == LocalhostMode.JSON: + feature_flag_sync_task = SplitSynchronizationTask( synchronizers.split_sync.synchronize_splits, cfg['featuresRefreshRate'], - ), None, None, None, None, + ) + segment_sync_task = SegmentSynchronizationTask( + synchronizers.segment_sync.synchronize_segments, + cfg['segmentsRefreshRate'], + ) + tasks = SplitTasks( + feature_flag_sync_task, + segment_sync_task, + None, None, None, + internal_events_task=internal_events_task ) sdk_metadata = util.get_metadata(cfg) ready_event = threading.Event() - synchronizer = LocalhostSynchronizer(synchronizers, tasks) - manager = Manager(ready_event, synchronizer, None, False, sdk_metadata) - manager.start() + synchronizer = LocalhostSynchronizer(synchronizers, tasks, localhost_mode) + manager = Manager(ready_event, synchronizer, None, False, sdk_metadata, telemetry_runtime_producer) + +# TODO: BUR is only applied for Localhost JSON mode, in future legacy and yaml will also use BUR + if localhost_mode == LocalhostMode.JSON: + initialization_thread = threading.Thread(target=manager.start, name="SDKInitializer", daemon=True) + initialization_thread.start() + else: + manager.start() + recorder = StandardRecorder( - ImpressionsManager(cfg['impressionsMode'], True, None), + ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer), storages['events'], storages['impressions'], + telemetry_evaluation_producer, + telemetry_runtime_producer ) + internal_events_task.start() + return SplitFactory( 'localhost', storages, False, recorder, + internal_events_queue, + events_manager, manager, - ready_event + ready_event, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + telemetry_submitter=LocalhostTelemetrySubmitter(), + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments']) + ) + +async def _build_localhost_factory_async(cfg): + """Build and return a localhost async factory for testing/development purposes.""" + telemetry_storage = LocalhostTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTaskAsync(events_manager.notify_internal_event, internal_events_queue) + + storages = { + 'splits': InMemorySplitStorageAsync(internal_events_queue), + 'segments': InMemorySegmentStorageAsync(internal_events_queue), # not used, just to avoid possible future errors. + 'rule_based_segments': InMemoryRuleBasedSegmentStorageAsync(internal_events_queue), + 'impressions': LocalhostImpressionsStorageAsync(), + 'events': LocalhostEventsStorageAsync(), + } + localhost_mode = LocalhostMode.JSON if cfg['splitFile'][-5:].lower() == '.json' else LocalhostMode.LEGACY + synchronizers = SplitSynchronizers( + LocalSplitSynchronizerAsync(cfg['splitFile'], + storages['splits'], + storages['rule_based_segments'], + localhost_mode), + LocalSegmentSynchronizerAsync(cfg['segmentDirectory'], storages['splits'], storages['segments']), + None, None, None, + ) + + feature_flag_sync_task = None + segment_sync_task = None + if cfg['localhostRefreshEnabled'] and localhost_mode == LocalhostMode.JSON: + feature_flag_sync_task = SplitSynchronizationTaskAsync( + synchronizers.split_sync.synchronize_splits, + cfg['featuresRefreshRate'], + ) + segment_sync_task = SegmentSynchronizationTaskAsync( + synchronizers.segment_sync.synchronize_segments, + cfg['segmentsRefreshRate'], + ) + tasks = SplitTasks( + feature_flag_sync_task, + segment_sync_task, + None, None, None, + internal_events_task=internal_events_task ) + sdk_metadata = util.get_metadata(cfg) + synchronizer = LocalhostSynchronizerAsync(synchronizers, tasks, localhost_mode) + manager = ManagerAsync(synchronizer, None, False, sdk_metadata, telemetry_runtime_producer) + +# TODO: BUR is only applied for Localhost JSON mode, in future legacy and yaml will also use BUR + manager_start_task = None + if localhost_mode == LocalhostMode.JSON: + manager_start_task = asyncio.get_running_loop().create_task(manager.start()) + else: + await manager.start() + + recorder = StandardRecorderAsync( + ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer), + storages['events'], + storages['impressions'], + telemetry_evaluation_producer, + telemetry_runtime_producer + ) + internal_events_task.start() + + return SplitFactoryAsync( + 'localhost', + storages, + False, + recorder, + internal_events_queue, + events_manager, + manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + telemetry_submitter=LocalhostTelemetrySubmitterAsync(), + manager_start_task=manager_start_task, + fallback_treatment_calculator=FallbackTreatmentCalculator(cfg['fallbackTreatments']) + ) def get_factory(api_key, **kwargs): """Build and return the appropriate factory.""" - try: - _INSTANTIATED_FACTORIES_LOCK.acquire() - if _INSTANTIATED_FACTORIES: - if api_key in _INSTANTIATED_FACTORIES: + _INSTANTIATED_FACTORIES_LOCK.acquire() + if _INSTANTIATED_FACTORIES: + if api_key in _INSTANTIATED_FACTORIES: + if _INSTANTIATED_FACTORIES[api_key] > 0: _LOGGER.warning( - "factory instantiation: You already have %d %s with this API Key. " + "factory instantiation: You already have %d %s with this SDK Key. " "We recommend keeping only one instance of the factory at all times " "(Singleton pattern) and reusing it throughout your application.", _INSTANTIATED_FACTORIES[api_key], 'factory' if _INSTANTIATED_FACTORIES[api_key] == 1 else 'factories' ) - else: + else: + _LOGGER.warning( + "factory instantiation: You already have an instance of the Split factory. " + "Make sure you definitely want this additional instance. " + "We recommend keeping only one instance of the factory at all times " + "(Singleton pattern) and reusing it throughout your application." + ) + + _INSTANTIATED_FACTORIES.update([api_key]) + _INSTANTIATED_FACTORIES_LOCK.release() + + config_raw = kwargs.get('config', {}) + total_flag_sets, invalid_flag_sets = _get_total_and_invalid_flag_sets(config_raw) + + config = sanitize_config(api_key, config_raw) + + if config['operationMode'] == 'localhost': + split_factory = _build_localhost_factory(config) + elif config['storageType'] == 'redis': + split_factory = _build_redis_factory(api_key, config) + elif config['storageType'] == 'pluggable': + split_factory = _build_pluggable_factory(api_key, config) + else: + split_factory = _build_in_memory_factory( + api_key, + config, + kwargs.get('sdk_api_base_url'), + kwargs.get('events_api_base_url'), + kwargs.get('auth_api_base_url'), + kwargs.get('streaming_api_base_url'), + kwargs.get('telemetry_api_base_url'), + total_flag_sets, + invalid_flag_sets) + + return split_factory + +async def get_factory_async(api_key, **kwargs): + """Build and return the appropriate factory.""" + _INSTANTIATED_FACTORIES_LOCK.acquire() + if _INSTANTIATED_FACTORIES: + if api_key in _INSTANTIATED_FACTORIES: + if _INSTANTIATED_FACTORIES[api_key] > 0: _LOGGER.warning( - "factory instantiation: You already have an instance of the Split factory. " - "Make sure you definitely want this additional instance. " + "factory instantiation: You already have %d %s with this SDK Key. " "We recommend keeping only one instance of the factory at all times " - "(Singleton pattern) and reusing it throughout your application." + "(Singleton pattern) and reusing it throughout your application.", + _INSTANTIATED_FACTORIES[api_key], + 'factory' if _INSTANTIATED_FACTORIES[api_key] == 1 else 'factories' ) - - config = sanitize_config(api_key, kwargs.get('config', {})) - - if config['operationMode'] == 'localhost-standalone': - return _build_localhost_factory(config) - - if config['operationMode'] == 'redis-consumer': - return _build_redis_factory(api_key, config) - - return _build_in_memory_factory( - api_key, - config, - kwargs.get('sdk_api_base_url'), - kwargs.get('events_api_base_url'), - kwargs.get('auth_api_base_url'), - kwargs.get('streaming_api_base_url') - ) - finally: - _INSTANTIATED_FACTORIES.update([api_key]) - _INSTANTIATED_FACTORIES_LOCK.release() + else: + _LOGGER.warning( + "factory instantiation: You already have an instance of the Split factory. " + "Make sure you definitely want this additional instance. " + "We recommend keeping only one instance of the factory at all times " + "(Singleton pattern) and reusing it throughout your application." + ) + + _INSTANTIATED_FACTORIES.update([api_key]) + _INSTANTIATED_FACTORIES_LOCK.release() + + config_raw = kwargs.get('config', {}) + total_flag_sets, invalid_flag_sets = _get_total_and_invalid_flag_sets(config_raw) + + config = sanitize_config(api_key, config_raw) + if config['operationMode'] == 'localhost': + split_factory = await _build_localhost_factory_async(config) + elif config['storageType'] == 'redis': + split_factory = await _build_redis_factory_async(api_key, config) + elif config['storageType'] == 'pluggable': + split_factory = await _build_pluggable_factory_async(api_key, config) + else: + split_factory = await _build_in_memory_factory_async( + api_key, + config, + kwargs.get('sdk_api_base_url'), + kwargs.get('events_api_base_url'), + kwargs.get('auth_api_base_url'), + kwargs.get('streaming_api_base_url'), + kwargs.get('telemetry_api_base_url'), + total_flag_sets, + invalid_flag_sets) + return split_factory + +def _get_active_and_redundant_count(): + redundant_factory_count = 0 + active_factory_count = 0 + _INSTANTIATED_FACTORIES_LOCK.acquire() + for item in _INSTANTIATED_FACTORIES: + redundant_factory_count += _INSTANTIATED_FACTORIES[item] - 1 + active_factory_count += _INSTANTIATED_FACTORIES[item] + _INSTANTIATED_FACTORIES_LOCK.release() + return redundant_factory_count, active_factory_count + +def _get_total_and_invalid_flag_sets(config_raw): + total_flag_sets = 0 + invalid_flag_sets = 0 + if config_raw.get('flagSetsFilter') is not None and isinstance(config_raw.get('flagSetsFilter'), list): + total_flag_sets = len(config_raw.get('flagSetsFilter')) + invalid_flag_sets = total_flag_sets - len(input_validator.validate_flag_sets(config_raw.get('flagSetsFilter'), 'Telemetry Init')) + + return total_flag_sets, invalid_flag_sets \ No newline at end of file diff --git a/splitio/client/input_validator.py b/splitio/client/input_validator.py index 2ca42e1f..dfded942 100644 --- a/splitio/client/input_validator.py +++ b/splitio/client/input_validator.py @@ -3,18 +3,21 @@ import logging import re import math +import inspect -from splitio.api import APIException -from splitio.api.commons import FetchOptions from splitio.client.key import Key +from splitio.client import client from splitio.engine.evaluator import CONTROL +from splitio.models.fallback_treatment import FallbackTreatment _LOGGER = logging.getLogger(__name__) MAX_LENGTH = 250 EVENT_TYPE_PATTERN = r'^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$' MAX_PROPERTIES_LENGTH_BYTES = 32768 - +_FLAG_SETS_REGEX = '^[a-z0-9][_a-z0-9]{0,49}$' +_FALLBACK_TREATMENT_REGEX = '^[0-9]+[.a-zA-Z0-9_-]*$|^[a-zA-Z]+[a-zA-Z0-9_-]*$' +_FALLBACK_TREATMENT_SIZE = 100 def _check_not_null(value, name, operation): """ @@ -33,6 +36,7 @@ def _check_not_null(value, name, operation): _LOGGER.error('%s: you passed a null %s, %s must be a non-empty string.', operation, name, name) return False + return True @@ -55,6 +59,7 @@ def _check_is_string(value, name, operation): operation, name, name ) return False + return True @@ -75,10 +80,11 @@ def _check_string_not_empty(value, name, operation): _LOGGER.error('%s: you passed an empty %s, %s must be a non-empty string.', operation, name, name) return False + return True -def _check_string_matches(value, operation, pattern): +def _check_string_matches(value, operation, pattern, name, length): """ Check if value is adhere to a regular expression passed. @@ -91,16 +97,17 @@ def _check_string_matches(value, operation, pattern): :return: The result of validation :rtype: True|False """ - if not re.match(pattern, value): + if re.search(pattern, value) is None or re.search(pattern, value).group() != value: _LOGGER.error( - '%s: you passed %s, event_type must ' + + '%s: you passed %s, %s must ' + 'adhere to the regular expression %s. ' + - 'This means an event name must be alphanumeric, cannot be more ' + - 'than 80 characters long, and can only include a dash, underscore, ' + + 'This means %s must be alphanumeric, cannot be more ' + + 'than %s characters long, and can only include a dash, underscore, ' + 'period, or colon as separators of alphanumeric characters.', - operation, value, pattern + operation, value, name, pattern, name, length ) return False + return True @@ -119,6 +126,7 @@ def _check_can_convert(value, name, operation): """ if isinstance(value, str): return value + else: # check whether if isnan and isinf are really necessary if isinstance(value, bool) or (not isinstance(value, Number)) or math.isnan(value) \ @@ -126,6 +134,7 @@ def _check_can_convert(value, name, operation): _LOGGER.error('%s: you passed an invalid %s, %s must be a non-empty string.', operation, name, name) return None + _LOGGER.warning('%s: %s %s is not of type string, converting.', operation, name, value) return str(value) @@ -148,6 +157,7 @@ def _check_valid_length(value, name, operation): _LOGGER.error('%s: %s too long - must be %s characters or less.', operation, name, MAX_LENGTH) return False + return True @@ -164,21 +174,21 @@ def _check_valid_object_key(key, name, operation): :return: The result of validation :rtype: str|None """ - if key is None: - _LOGGER.error( - '%s: you passed a null %s, %s must be a non-empty string.', - operation, name, name) + if not _check_not_null(key, name, operation): return None + if isinstance(key, str): if not _check_string_not_empty(key, name, operation): return None + key_str = _check_can_convert(key, name, operation) if key_str is None or not _check_valid_length(key_str, name, operation): return None + return key_str -def _remove_empty_spaces(value, operation): +def _remove_empty_spaces(value, name, operation): """ Check if an string has whitespaces. @@ -191,9 +201,15 @@ def _remove_empty_spaces(value, operation): """ strip_value = value.strip() if value != strip_value: - _LOGGER.warning("%s: feature_name '%s' has extra whitespace, trimming.", operation, value) + _LOGGER.warning("%s: %s '%s' has extra whitespace, trimming.", operation, name, value) return strip_value +def _convert_str_to_lower(value, name, operation): + lower_value = value.lower() + if value != lower_value: + _LOGGER.warning("%s: %s '%s' should be all lowercase - converting string to lowercase", operation, name, value) + return lower_value + def validate_key(key, method_name): """ @@ -210,18 +226,19 @@ def validate_key(key, method_name): """ matching_key_result = None bucketing_key_result = None - if key is None: - _LOGGER.error('%s: you passed a null key, key must be a non-empty string.', method_name) + if not _check_not_null(key, 'key', method_name): return None, None if isinstance(key, Key): matching_key_result = _check_valid_object_key(key.matching_key, 'matching_key', method_name) if matching_key_result is None: return None, None + bucketing_key_result = _check_valid_object_key(key.bucketing_key, 'bucketing_key', method_name) if bucketing_key_result is None: return None, None + else: key_str = _check_can_convert(key, 'key', method_name) if key_str is not None and \ @@ -231,31 +248,28 @@ def validate_key(key, method_name): return matching_key_result, bucketing_key_result -def validate_feature_name(feature_name, should_validate_existance, split_storage, method_name): +def _validate_feature_flag_name(feature_flag_name, method_name): + if (not _check_not_null(feature_flag_name, 'feature_flag_name', method_name)) or \ + (not _check_is_string(feature_flag_name, 'feature_flag_name', method_name)) or \ + (not _check_string_not_empty(feature_flag_name, 'feature_flag_name', method_name)): + return False + + return True + + +def validate_feature_flag_name(feature_flag_name, method_name): """ - Check if feature_name is valid for get_treatment. + Check if feature flag name is valid for get_treatment. - :param feature_name: feature_name to be checked - :type feature_name: str - :return: feature_name + :param feature_flag_name: feature flag name to be checked + :type feature_flag_name: str + :return: feature_flag_name :rtype: str|None """ - if (not _check_not_null(feature_name, 'feature_name', method_name)) or \ - (not _check_is_string(feature_name, 'feature_name', method_name)) or \ - (not _check_string_not_empty(feature_name, 'feature_name', method_name)): + if not _validate_feature_flag_name(feature_flag_name, method_name): return None - if should_validate_existance and split_storage.get(feature_name) is None: - _LOGGER.warning( - "%s: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", - method_name, - feature_name - ) - return None - - return _remove_empty_spaces(feature_name, method_name) - + return _remove_empty_spaces(feature_flag_name, 'feature flag name', method_name) def validate_track_key(key): """ @@ -268,41 +282,76 @@ def validate_track_key(key): """ if not _check_not_null(key, 'key', 'track'): return None + key_str = _check_can_convert(key, 'key', 'track') if key_str is None or \ (not _check_string_not_empty(key_str, 'key', 'track')) or \ (not _check_valid_length(key_str, 'key', 'track')): return None + return key_str -def validate_traffic_type(traffic_type, should_validate_existance, split_storage): +def _validate_traffic_type_value(traffic_type): + if (not _check_not_null(traffic_type, 'traffic_type', 'track')) or \ + (not _check_is_string(traffic_type, 'traffic_type', 'track')) or \ + (not _check_string_not_empty(traffic_type, 'traffic_type', 'track')): + return False + + return True + +def validate_traffic_type(traffic_type, should_validate_existance, feature_flag_storage): """ Check if traffic_type is valid for track. :param traffic_type: traffic_type to be checked :type traffic_type: str - :param should_validate_existance: Whether to check for existante in the split storage. + :param should_validate_existance: Whether to check for existante in the feature flag storage. :type should_validate_existance: bool - :param split_storage: Split storage. - :param split_storage: splitio.storages.SplitStorage + :param feature_flag_storage: Feature flag storage. + :param feature_flag_storage: splitio.storages.SplitStorage :return: traffic_type :rtype: str|None """ - if (not _check_not_null(traffic_type, 'traffic_type', 'track')) or \ - (not _check_is_string(traffic_type, 'traffic_type', 'track')) or \ - (not _check_string_not_empty(traffic_type, 'traffic_type', 'track')): + if not _validate_traffic_type_value(traffic_type): return None - if not traffic_type.islower(): - _LOGGER.warning('track: %s should be all lowercase - converting string to lowercase.', - traffic_type) - traffic_type = traffic_type.lower() - if should_validate_existance and not split_storage.is_valid_traffic_type(traffic_type): + traffic_type = _convert_str_to_lower(traffic_type, 'traffic type', 'track') + + if should_validate_existance and not feature_flag_storage.is_valid_traffic_type(traffic_type): _LOGGER.warning( - 'track: Traffic Type %s does not have any corresponding Splits in this environment, ' + 'track: Traffic Type %s does not have any corresponding Feature flags in this environment, ' 'make sure you\'re tracking your events to a valid traffic type defined ' - 'in the Split console.', + 'in the Split user interface.', + traffic_type + ) + + return traffic_type + + +async def validate_traffic_type_async(traffic_type, should_validate_existance, feature_flag_storage): + """ + Check if traffic_type is valid for track. + + :param traffic_type: traffic_type to be checked + :type traffic_type: str + :param should_validate_existance: Whether to check for existante in the feature flag storage. + :type should_validate_existance: bool + :param feature_flag_storage: Feature flag storage. + :param feature_flag_storage: splitio.storages.SplitStorage + :return: traffic_type + :rtype: str|None + """ + if not _validate_traffic_type_value(traffic_type): + return None + + traffic_type = _convert_str_to_lower(traffic_type, 'traffic type', 'track') + + if should_validate_existance and not await feature_flag_storage.is_valid_traffic_type(traffic_type): + _LOGGER.warning( + 'track: Traffic Type %s does not have any corresponding Feature flags in this environment, ' + 'make sure you\'re tracking your events to a valid traffic type defined ' + 'in the Split user interface.', traffic_type ) @@ -321,8 +370,9 @@ def validate_event_type(event_type): if (not _check_not_null(event_type, 'event_type', 'track')) or \ (not _check_is_string(event_type, 'event_type', 'track')) or \ (not _check_string_not_empty(event_type, 'event_type', 'track')) or \ - (not _check_string_matches(event_type, 'track', EVENT_TYPE_PATTERN)): + (not _check_string_matches(event_type, 'track', EVENT_TYPE_PATTERN, 'an event name', 80)): return None + return event_type @@ -337,91 +387,142 @@ def validate_value(value): """ if value is None: return None + if (not isinstance(value, Number)) or isinstance(value, bool): _LOGGER.error('track: value must be a number.') return False + return value +def validate_manager_feature_flag_name(feature_flag_name, should_validate_existance, feature_flag_storage): + """ + Check if feature flag name is valid for track. -def validate_manager_feature_name(feature_name, should_validate_existance, split_storage): + :param feature_flag_name: feature flag name to be checked + :type feature_flag_name: str + :return: feature_flag_name + :rtype: str|None """ - Check if feature_name is valid for track. + if not _validate_feature_flag_name(feature_flag_name, 'split'): + return None - :param feature_name: feature_name to be checked - :type feature_name: str - :return: feature_name + feature_flag = feature_flag_storage.get(feature_flag_name) + if should_validate_existance and feature_flag is None: + _LOGGER.warning( + "split: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + feature_flag_name + ) + return None + + return feature_flag + +async def validate_manager_feature_flag_name_async(feature_flag_name, should_validate_existance, feature_flag_storage): + """ + Check if feature flag name is valid for track. + + :param feature_flag_name: feature flag name to be checked + :type feature_flag_name: str + :return: feature_flag_name :rtype: str|None """ - if (not _check_not_null(feature_name, 'feature_name', 'split')) or \ - (not _check_is_string(feature_name, 'feature_name', 'split')) or \ - (not _check_string_not_empty(feature_name, 'feature_name', 'split')): + if not _validate_feature_flag_name(feature_flag_name, 'split'): return None - if should_validate_existance and split_storage.get(feature_name) is None: + feature_flag = await feature_flag_storage.get(feature_flag_name) + if should_validate_existance and feature_flag is None: _LOGGER.warning( "split: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", - feature_name + "please double check what Feature flags exist in the Split user interface.", + feature_flag_name ) return None - return feature_name + return feature_flag +def validate_feature_flag_names(feature_flags, method_name): + """ + Check if feature flag name is valid for track. -def validate_features_get_treatments( # pylint: disable=invalid-name + :param feature_flag_name: feature flag name to be checked + :type feature_flag_name: str + """ + for feature_flag in feature_flags.keys(): + if feature_flags[feature_flag] is None: + _LOGGER.warning( + "%s: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + method_name, feature_flag + ) + +def _check_feature_flag_instance(feature_flags, method_name): + if feature_flags is None or not isinstance(feature_flags, list): + _LOGGER.error("%s: feature flag names must be a non-empty array.", method_name) + return False + + if not feature_flags: + _LOGGER.error("%s: feature flag names must be a non-empty array.", method_name) + return False + + return True + + +def _get_filtered_feature_flag(feature_flags, method_name): + return set( + _remove_empty_spaces(feature_flag, 'feature flag name', method_name) for feature_flag in feature_flags + if feature_flag is not None and + _check_is_string(feature_flag, 'feature flag name', method_name) and + _check_string_not_empty(feature_flag, 'feature flag name', method_name) + ) + + +def validate_feature_flags_get_treatments( # pylint: disable=invalid-name method_name, - features, - should_validate_existance=False, - split_storage=None -): + feature_flag_names, + ): """ - Check if features is valid for get_treatments. + Check if feature flags is valid for get_treatments. - :param features: array of features - :type features: list - :return: filtered_features + :param feature_flags: array of feature flags + :type feature_flags: list + :return: filtered_feature_flags :rtype: tuple """ - if features is None or not isinstance(features, list): - _LOGGER.error("%s: feature_names must be a non-empty array.", method_name) - return None, None - if not features: - _LOGGER.error("%s: feature_names must be a non-empty array.", method_name) - return None, None - filtered_features = set( - _remove_empty_spaces(feature, method_name) for feature in features - if feature is not None and - _check_is_string(feature, 'feature_name', method_name) and - _check_string_not_empty(feature, 'feature_name', method_name) - ) - if not filtered_features: - _LOGGER.error("%s: feature_names must be a non-empty array.", method_name) - return None, None - - if not should_validate_existance: - return filtered_features, [] + if not _check_feature_flag_instance(feature_flag_names, method_name): + return None - valid_missing_features = set(f for f in filtered_features if split_storage.get(f) is None) - for missing_feature in valid_missing_features: - _LOGGER.warning( - "%s: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", - method_name, - missing_feature - ) - return filtered_features - valid_missing_features, valid_missing_features + filtered_feature_flags = _get_filtered_feature_flag(feature_flag_names, method_name) + if not filtered_feature_flags: + _LOGGER.error("%s: feature flag names must be a non-empty array.", method_name) + return None + valid_feature_flags = [] + for ff in filtered_feature_flags: + ff = _remove_empty_spaces(ff, 'feature flag name', method_name) + valid_feature_flags.append(ff) + return valid_feature_flags -def generate_control_treatments(features, method_name): +def generate_control_treatments(feature_flags, fallback_treatment_calculator): """ - Generate valid features to control. + Generate valid feature flags to control. - :param features: array of features - :type features: list + :param feature_flags: array of feature flags + :type feature_flags: list :return: dict :rtype: dict|None """ - return {feature: (CONTROL, None) for feature in validate_features_get_treatments(method_name, features)[0]} + if not isinstance(feature_flags, list): + return {} + + to_return = {} + for feature_flag in feature_flags: + if isinstance(feature_flag, str) and len(feature_flag.strip())> 0: + fallback_treatment = fallback_treatment_calculator.resolve(feature_flag, "") + treatment = fallback_treatment.treatment + config = fallback_treatment.config + + to_return[feature_flag] = (treatment, config) + return to_return def validate_attributes(attributes, method_name): @@ -437,61 +538,49 @@ def validate_attributes(attributes, method_name): """ if attributes is None: return True + if not isinstance(attributes, dict): _LOGGER.error('%s: attributes must be of type dictionary.', method_name) return False + return True +def validate_evaluation_options(evaluation_options, method_name): + if evaluation_options == None: + return None + + if not isinstance(evaluation_options, client.EvaluationOptions): + _LOGGER.error("%s: evaluation options should be an instance of EvaluationOptions. Setting its value to None.", method_name) + return None + + return evaluation_options class _ApiLogFilter(logging.Filter): # pylint: disable=too-few-public-methods def filter(self, record): return record.name not in ('SegmentsAPI', 'HttpClient') -def validate_apikey_type(segment_api): - """ - Try to guess if the apikey is of browser type and let the user know. - - :param segment_api: Segments API client. - :type segment_api: splitio.api.segments.SegmentsAPI - """ - api_messages_filter = _ApiLogFilter() - _logger = logging.getLogger('splitio.api.segments') - try: - _logger.addFilter(api_messages_filter) # pylint: disable=protected-access - segment_api.fetch_segment('__SOME_INVALID_SEGMENT__', -1, FetchOptions()) - except APIException as exc: - if exc.status_code == 403: - _LOGGER.error('factory instantiation: you passed a browser type ' - + 'api_key, please grab an api key from the Split ' - + 'console that is of type sdk') - return False - finally: - _logger.removeFilter(api_messages_filter) # pylint: disable=protected-access - - # True doesn't mean that the APIKEY is right, only that it's not of type "browser" - return True - - -def validate_factory_instantiation(apikey): +def validate_factory_instantiation(sdk_key): """ Check if the factory if being instantiated with the appropriate arguments. - :param apikey: str - :type apikey: str + :param sdk_key: str + :type sdk_key: str :return: bool :rtype: True|False """ - if apikey == 'localhost': + if sdk_key == 'localhost': return True - if (not _check_not_null(apikey, 'apikey', 'factory_instantiation')) or \ - (not _check_is_string(apikey, 'apikey', 'factory_instantiation')) or \ - (not _check_string_not_empty(apikey, 'apikey', 'factory_instantiation')): + + if (not _check_not_null(sdk_key, 'sdk_key', 'factory_instantiation')) or \ + (not _check_is_string(sdk_key, 'sdk_key', 'factory_instantiation')) or \ + (not _check_string_not_empty(sdk_key, 'sdk_key', 'factory_instantiation')): return False + return True -def valid_properties(properties): +def valid_properties(properties, source): """ Check if properties is a valid dict and returns the properties that will be sent to the track method, avoiding unexpected types. @@ -505,8 +594,9 @@ def valid_properties(properties): if properties is None: return True, None, size + if not isinstance(properties, dict): - _LOGGER.error('track: properties must be of type dictionary.') + _LOGGER.error('%s: properties must be of type dictionary.', source) return False, None, 0 valid_properties = dict() @@ -521,9 +611,8 @@ def valid_properties(properties): if element is None: continue - if not isinstance(element, str) and not isinstance(element, Number) \ - and not isinstance(element, bool): - _LOGGER.warning('Property %s is of invalid type. Setting value to None', element) + if not _check_element_type(element): + _LOGGER.warning('%s: Property %s is of invalid type. Setting value to None', source, element) element = None valid_properties[property] = element @@ -533,12 +622,124 @@ def valid_properties(properties): if size > MAX_PROPERTIES_LENGTH_BYTES: _LOGGER.error( - 'The maximum size allowed for the properties is 32768 bytes. ' + - 'Current one is ' + str(size) + ' bytes. Event not queued' - ) + '%s: The maximum size allowed for the properties is 32768 bytes. ' + + 'Current one is ' + str(size) + ' bytes. Event not queued', source) return False, None, size if len(valid_properties.keys()) > 300: - _LOGGER.warning('Event has more than 300 properties. Some of them will be trimmed' + - ' when processed') + _LOGGER.warning('%s: Event has more than 300 properties. Some of them will be trimmed' + + ' when processed', source) return True, valid_properties if len(valid_properties) else None, size + +def _check_element_type(element): + if not isinstance(element, str) and not isinstance(element, Number) \ + and not isinstance(element, bool): + return False + + return True + +def validate_pluggable_adapter(config): + """ + Check if pluggable adapter contains the expected method signature + + :param config: config parameters + :type config: Dict + + :return: True if no issue found otherwise False + :rtype: bool + """ + if config.get('storageType') != 'pluggable': + return True + + if config.get('storageWrapper') is None: + _LOGGER.error("Expecting pluggable storage `wrapper` in options, but no valid wrapper instance was provided.") + return False + + if config.get('storagePrefix') is not None: + if not isinstance(config.get('storagePrefix'), str): + _LOGGER.error("Pluggable storage prefix should be string type only") + return False + + pluggable_adapter = config.get('storageWrapper') + if not isinstance(pluggable_adapter, object): + _LOGGER.error("Pluggable storage instance is not inherted from object class") + return False + + expected_methods = {'get': 1, 'get_items': 1, 'get_many': 1, 'set': 2, 'push_items': 2, + 'delete': 1, 'increment': 2, 'decrement': 2, 'get_keys_by_prefix': 1, + 'get_many': 1, 'add_items' : 2, 'remove_items': 2, 'item_contains': 2, + 'get_items_count': 1, 'expire': 2} + methods = inspect.getmembers(pluggable_adapter, predicate=inspect.ismethod) + for exp_method in expected_methods: + method_found = False + get_method_args = set() + for method in methods: + if exp_method == method[0]: + method_found = True + get_method_args = inspect.signature(method[1]).parameters + break + + if not method_found: + _LOGGER.error("Pluggable adapter does not have required method: %s" % exp_method) + return False + + if len(get_method_args) < expected_methods[exp_method]: + _LOGGER.error("Pluggable adapter method %s has less than required arguments count: %s : " % (exp_method, len(get_method_args))) + return False + + return True + +def validate_flag_sets(flag_sets, method_name): + """ + Validate flag sets list + :param flag_set: list of flag sets + :type flag_set: list[str] + :returns: Sanitized and sorted flag sets + :rtype: list[str] + """ + if not isinstance(flag_sets, list): + _LOGGER.warning("%s: flag sets parameter type should be list object, parameter is discarded", method_name) + return [] + + sanitized_flag_sets = set() + for flag_set in flag_sets: + if not _check_not_null(flag_set, 'flag set', method_name): + continue + + if not _check_is_string(flag_set, 'flag set', method_name): + continue + + flag_set = _remove_empty_spaces(flag_set, 'flag set', method_name) + flag_set = _convert_str_to_lower(flag_set, 'flag set', method_name) + + if not _check_string_matches(flag_set, method_name, _FLAG_SETS_REGEX, 'a flag set', 50): + continue + + sanitized_flag_sets.add(flag_set) + + return list(sanitized_flag_sets) + +def validate_fallback_treatment(fallback_treatment): + if not isinstance(fallback_treatment, FallbackTreatment): + _LOGGER.warning("Config: Fallback treatment instance should be FallbackTreatment, input is discarded") + return False + + if not isinstance(fallback_treatment.treatment, str): + _LOGGER.warning("Config: Fallback treatment value should be str type, input is discarded") + return False + + if not validate_regex_name(fallback_treatment.treatment): + _LOGGER.warning("Config: Fallback treatment should match regex %s", _FALLBACK_TREATMENT_REGEX) + return False + + if len(fallback_treatment.treatment) > _FALLBACK_TREATMENT_SIZE: + _LOGGER.warning("Config: Fallback treatment size should not exceed %s characters", _FALLBACK_TREATMENT_SIZE) + return False + + return True + +def validate_regex_name(name): + if re.match(_FALLBACK_TREATMENT_REGEX, name) == None: + return False + + return True \ No newline at end of file diff --git a/splitio/client/listener.py b/splitio/client/listener.py index 3d2ea62c..aa5e815a 100644 --- a/splitio/client/listener.py +++ b/splitio/client/listener.py @@ -8,8 +8,20 @@ class ImpressionListenerException(Exception): pass +class ImpressionListener(object, metaclass=abc.ABCMeta): + """Impression listener interface.""" + + @abc.abstractmethod + def log_impression(self, data): + """ + Accept and impression generated after an evaluation for custom user handling. -class ImpressionListenerWrapper(object): # pylint: disable=too-few-public-methods + :param data: Impression data in a dictionary format. + :type data: dict + """ + pass + +class ImpressionListenerBase(ImpressionListener): # pylint: disable=too-few-public-methods """ Impression listener safe-execution wrapper. @@ -31,6 +43,35 @@ def __init__(self, impression_listener, sdk_metadata): self.impression_listener = impression_listener self._metadata = sdk_metadata + def _construct_data(self, impression, attributes): + data = {} + data['impression'] = impression + data['attributes'] = attributes + data['sdk-language-version'] = self._metadata.sdk_version + data['instance-id'] = self._metadata.instance_name + return data + + def log_impression(self, impression, attributes=None): + pass + +class ImpressionListenerWrapper(ImpressionListenerBase): # pylint: disable=too-few-public-methods + """ + Impression listener safe-execution wrapper. + + Wrapper in charge of building all the data that client would require in case + of adding some logic with the treatment and impression results. + """ + def __init__(self, impression_listener, sdk_metadata): + """ + Class Constructor. + + :param impression_listener: User provided impression listener. + :type impression_listener: ImpressionListener + :param sdk_metadata: SDK version, instance name & IP + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + ImpressionListenerBase.__init__(self, impression_listener, sdk_metadata) + def log_impression(self, impression, attributes=None): """ Send an impression to the user-provided listener. @@ -40,26 +81,42 @@ def log_impression(self, impression, attributes=None): :param attributes: User provided attributes when calling get_treatment(s) :type attributes: dict """ - data = {} - data['impression'] = impression - data['attributes'] = attributes - data['sdk-language-version'] = self._metadata.sdk_version - data['instance-id'] = self._metadata.instance_name + data = self._construct_data(impression, attributes) try: self.impression_listener.log_impression(data) except Exception as exc: # pylint: disable=broad-except raise ImpressionListenerException('Error in log_impression user\'s method is throwing exceptions') from exc -class ImpressionListener(object, metaclass=abc.ABCMeta): - """Impression listener interface.""" +class ImpressionListenerWrapperAsync(ImpressionListenerBase): # pylint: disable=too-few-public-methods + """ + Impression listener safe-execution wrapper. - @abc.abstractmethod - def log_impression(self, data): + Wrapper in charge of building all the data that client would require in case + of adding some logic with the treatment and impression results. + """ + def __init__(self, impression_listener, sdk_metadata): """ - Accept and impression generated after an evaluation for custom user handling. + Class Constructor. - :param data: Impression data in a dictionary format. - :type data: dict + :param impression_listener: User provided impression listener. + :type impression_listener: ImpressionListener + :param sdk_metadata: SDK version, instance name & IP + :type sdk_metadata: splitio.client.util.SdkMetadata """ - pass + ImpressionListenerBase.__init__(self, impression_listener, sdk_metadata) + + async def log_impression(self, impression, attributes=None): + """ + Send an impression to the user-provided listener. + + :param impression: Imression data + :type impression: dict + :param attributes: User provided attributes when calling get_treatment(s) + :type attributes: dict + """ + data = self._construct_data(impression, attributes) + try: + await self.impression_listener.log_impression(data) + except Exception as exc: # pylint: disable=broad-except + raise ImpressionListenerException('Error in log_impression user\'s method is throwing exceptions') from exc diff --git a/splitio/client/localhost.py b/splitio/client/localhost.py index dec597a9..4cc87cc8 100644 --- a/splitio/client/localhost.py +++ b/splitio/client/localhost.py @@ -41,3 +41,34 @@ def pop_many(self, *_, **__): # pylint: disable=arguments-differ def clear(self, *_, **__): # pylint: disable=arguments-differ """Accept any arguments and do nothing.""" pass + +class LocalhostImpressionsStorageAsync(ImpressionStorage): + """Impression storage that doesn't cache anything.""" + + async def put(self, *_, **__): # pylint: disable=arguments-differ + """Accept any arguments and do nothing.""" + pass + + async def pop_many(self, *_, **__): # pylint: disable=arguments-differ + """Accept any arguments and do nothing.""" + pass + + async def clear(self, *_, **__): # pylint: disable=arguments-differ + """Accept any arguments and do nothing.""" + pass + + +class LocalhostEventsStorageAsync(EventStorage): + """Impression storage that doesn't cache anything.""" + + async def put(self, *_, **__): # pylint: disable=arguments-differ + """Accept any arguments and do nothing.""" + pass + + async def pop_many(self, *_, **__): # pylint: disable=arguments-differ + """Accept any arguments and do nothing.""" + pass + + async def clear(self, *_, **__): # pylint: disable=arguments-differ + """Accept any arguments and do nothing.""" + pass diff --git a/splitio/client/manager.py b/splitio/client/manager.py index dfb09f5a..e621aeb1 100644 --- a/splitio/client/manager.py +++ b/splitio/client/manager.py @@ -19,6 +19,7 @@ def __init__(self, factory): """ self._factory = factory self._storage = factory._get_storage('splits') # pylint: disable=protected-access + self._telemetry_init_producer = factory._telemetry_init_producer def split_names(self): """ @@ -30,11 +31,13 @@ def split_names(self): if self._factory.destroyed: _LOGGER.error("Client has already been destroyed - no calls possible.") return [] + if self._factory._waiting_fork(): _LOGGER.error("Client is not ready - no calls possible") return [] if not self._factory.ready: + self._telemetry_init_producer.record_not_ready_usage() _LOGGER.warning( "split_names: The SDK is not ready, results may be incorrect. " "Make sure to wait for SDK readiness before using this method" @@ -52,11 +55,13 @@ def splits(self): if self._factory.destroyed: _LOGGER.error("Client has already been destroyed - no calls possible.") return [] + if self._factory._waiting_fork(): _LOGGER.error("Client is not ready - no calls possible") return [] if not self._factory.ready: + self._telemetry_init_producer.record_not_ready_usage() _LOGGER.warning( "splits: The SDK is not ready, results may be incorrect. " "Make sure to wait for SDK readiness before using this method" @@ -77,24 +82,117 @@ def split(self, feature_name): if self._factory.destroyed: _LOGGER.error("Client has already been destroyed - no calls possible.") return None + if self._factory._waiting_fork(): _LOGGER.error("Client is not ready - no calls possible") return None - feature_name = input_validator.validate_manager_feature_name( + feature_flag = input_validator.validate_manager_feature_flag_name( feature_name, self._factory.ready, self._storage ) if not self._factory.ready: + self._telemetry_init_producer.record_not_ready_usage() _LOGGER.warning( "split: The SDK is not ready, results may be incorrect. " "Make sure to wait for SDK readiness before using this method" ) - if feature_name is None: + return feature_flag.to_split_view() if feature_flag is not None else None + +class SplitManagerAsync(object): + """Split Manager. Gives insights on data cached by splits.""" + + def __init__(self, factory): + """ + Class constructor. + + :param factory: Factory containing all storage references. + :type factory: splitio.client.factory.SplitFactory + """ + self._factory = factory + self._storage = factory._get_storage('splits') # pylint: disable=protected-access + self._telemetry_init_producer = factory._telemetry_init_producer + + async def split_names(self): + """ + Get the name of fetched splits. + + :return: A list of str + :rtype: list + """ + if self._factory.destroyed: + _LOGGER.error("Client has already been destroyed - no calls possible.") + return [] + + if self._factory._waiting_fork(): + _LOGGER.error("Client is not ready - no calls possible") + return [] + + if not self._factory.ready: + await self._telemetry_init_producer.record_not_ready_usage() + _LOGGER.warning( + "split_names: The SDK is not ready, results may be incorrect. " + "Make sure to wait for SDK readiness before using this method" + ) + + return await self._storage.get_split_names() + + async def splits(self): + """ + Get the fetched splits. Subclasses need to override this method. + + :return: A List of SplitView. + :rtype: list() + """ + if self._factory.destroyed: + _LOGGER.error("Client has already been destroyed - no calls possible.") + return [] + + if self._factory._waiting_fork(): + _LOGGER.error("Client is not ready - no calls possible") + return [] + + if not self._factory.ready: + await self._telemetry_init_producer.record_not_ready_usage() + _LOGGER.warning( + "splits: The SDK is not ready, results may be incorrect. " + "Make sure to wait for SDK readiness before using this method" + ) + + return [split.to_split_view() for split in await self._storage.get_all_splits()] + + async def split(self, feature_name): + """ + Get the splitView of feature_name. Subclasses need to override this method. + + :param feature_name: Name of the feture to retrieve. + :type feature_name: str + + :return: The SplitView instance. + :rtype: splitio.models.splits.SplitView + """ + if self._factory.destroyed: + _LOGGER.error("Client has already been destroyed - no calls possible.") return None - split = self._storage.get(feature_name) - return split.to_split_view() if split is not None else None + if self._factory._waiting_fork(): + _LOGGER.error("Client is not ready - no calls possible") + return None + + feature_flag = await input_validator.validate_manager_feature_flag_name_async( + feature_name, + self._factory.ready, + self._storage + ) + + if not self._factory.ready: + await self._telemetry_init_producer.record_not_ready_usage() + _LOGGER.warning( + "split: The SDK is not ready, results may be incorrect. " + "Make sure to wait for SDK readiness before using this method" + ) + + return feature_flag.to_split_view() if feature_flag is not None else None diff --git a/splitio/client/util.py b/splitio/client/util.py index 040a09ae..b5b693cb 100644 --- a/splitio/client/util.py +++ b/splitio/client/util.py @@ -30,6 +30,7 @@ def _get_hostname(ip_address): def _get_hostname_and_ip(config): if config.get('IPAddressesEnabled') is False: return 'NA', 'NA' + ip_from_config = config.get('machineIp') machine_from_config = config.get('machineName') ip_address = ip_from_config if ip_from_config is not None else _get_ip() @@ -49,4 +50,4 @@ def get_metadata(config): """ version = 'python-%s' % __version__ ip_address, hostname = _get_hostname_and_ip(config) - return SdkMetadata(version, hostname, ip_address) + return SdkMetadata(version, hostname, ip_address) \ No newline at end of file diff --git a/splitio/engine/evaluator.py b/splitio/engine/evaluator.py index 489c9ba2..b47db5c5 100644 --- a/splitio/engine/evaluator.py +++ b/splitio/engine/evaluator.py @@ -1,11 +1,18 @@ """Split evaluator module.""" import logging -from splitio.models.grammar.condition import ConditionType -from splitio.models.impressions import Label +from collections import namedtuple +from splitio.models.impressions import Label +from splitio.models.grammar.condition import ConditionType +from splitio.models.grammar.matchers.misc import DependencyMatcher +from splitio.models.grammar.matchers.keys import UserDefinedSegmentMatcher +from splitio.models.grammar.matchers import RuleBasedSegmentMatcher +from splitio.models.grammar.matchers.prerequisites import PrerequisitesMatcher +from splitio.models.rule_based_segments import SegmentType +from splitio.optional.loaders import asyncio CONTROL = 'control' - +EvaluationContext = namedtuple('EvaluationContext', ['flags', 'segment_memberships', 'rbs_segments']) _LOGGER = logging.getLogger(__name__) @@ -13,189 +20,236 @@ class Evaluator(object): # pylint: disable=too-few-public-methods """Split Evaluator class.""" - def __init__(self, split_storage, segment_storage, splitter): + def __init__(self, splitter, fallback_treatment_calculator=None): """ Construct a Evaluator instance. - :param split_storage: Split storage. - :type split_storage: splitio.storage.SplitStorage - - :param split_storage: Storage storage. - :type split_storage: splitio.storage.SegmentStorage + :param splitter: partition object. + :type splitter: splitio.engine.splitters.Splitters """ - self._split_storage = split_storage - self._segment_storage = segment_storage self._splitter = splitter + self._fallback_treatment_calculator = fallback_treatment_calculator - def _evaluate_treatment(self, feature, matching_key, bucketing_key, attributes, split): + def eval_many_with_context(self, key, bucketing, features, attrs, ctx): """ - Evaluate the user submitted data against a feature and return the resulting treatment. - - :param feature: The feature for which to get the treatment - :type feature: str - - :param matching_key: The matching_key for which to get the treatment - :type matching_key: str - - :param bucketing_key: The bucketing_key for which to get the treatment - :type bucketing_key: str - - :param attributes: An optional dictionary of attributes - :type attributes: dict - - :param split: Split object - :type attributes: splitio.models.splits.Split|None + ... + """ + # we can do a linear evaluation here, since all the dependencies are already fetched + return { + name: self.eval_with_context(key, bucketing, name, attrs, ctx) + for name in features + } - :return: The treatment for the key and split - :rtype: object + def eval_with_context(self, key, bucketing, feature_name, attrs, ctx): + """ + ... """ label = '' _treatment = CONTROL _change_number = -1 - if split is None: + feature = ctx.flags.get(feature_name) + if not feature: _LOGGER.warning('Unknown or invalid feature: %s', feature) label = Label.SPLIT_NOT_FOUND + fallback_treatment = self._fallback_treatment_calculator.resolve(feature_name, label) + label = fallback_treatment.label + _treatment = fallback_treatment.treatment + config = fallback_treatment.config else: - _change_number = split.change_number - if split.killed: + _change_number = feature.change_number + if feature.killed: label = Label.KILLED - _treatment = split.default_treatment + _treatment = feature.default_treatment else: - treatment, label = self._get_treatment_for_split( - split, - matching_key, - bucketing_key, - attributes - ) - if treatment is None: - label = Label.NO_CONDITION_MATCHED - _treatment = split.default_treatment - else: - _treatment = treatment - + label, _treatment = self._check_prerequisites(feature, bucketing, key, attrs, ctx, label, _treatment) + label, _treatment = self._get_treatment(feature, bucketing, key, attrs, ctx, label, _treatment) + config = feature.get_configurations_for(_treatment) + return { 'treatment': _treatment, - 'configurations': split.get_configurations_for(_treatment) if split else None, + 'configurations': config, 'impression': { 'label': label, 'change_number': _change_number - } + }, + 'impressions_disabled': feature.impressions_disabled if feature else None } - - def evaluate_feature(self, feature, matching_key, bucketing_key, attributes=None): + + def _get_treatment(self, feature, bucketing, key, attrs, ctx, label, _treatment): + if _treatment == CONTROL: + treatment, label = self._treatment_for_flag(feature, key, bucketing, attrs, ctx) + if treatment is None: + label = Label.NO_CONDITION_MATCHED + _treatment = feature.default_treatment + else: + _treatment = treatment + + return label, _treatment + + def _check_prerequisites(self, feature, bucketing, key, attrs, ctx, label, _treatment): + if feature.prerequisites is not None: + prerequisites_matcher = PrerequisitesMatcher(feature.prerequisites) + if not prerequisites_matcher.match(key, attrs, { + 'evaluator': self, + 'bucketing_key': bucketing, + 'ec': ctx}): + label = Label.PREREQUISITES_NOT_MET + _treatment = feature.default_treatment + + return label, _treatment + + + def _treatment_for_flag(self, flag, key, bucketing, attributes, ctx): """ - Evaluate the user submitted data against a feature and return the resulting treatment. - - :param feature: The feature for which to get the treatment - :type feature: str + ... + """ + bucketing = bucketing if bucketing is not None else key + rollout = False + for condition in flag.conditions: + if not rollout and condition.condition_type == ConditionType.ROLLOUT: + if flag.traffic_allocation < 100: + bucket = self._splitter.get_bucket(bucketing, flag.traffic_allocation_seed, flag.algo) + if bucket > flag.traffic_allocation: + return flag.default_treatment, Label.NOT_IN_SPLIT - :param matching_key: The matching_key for which to get the treatment - :type matching_key: str + rollout = True - :param bucketing_key: The bucketing_key for which to get the treatment - :type bucketing_key: str + if condition.matches(key, attributes, { + 'evaluator': self, + 'bucketing_key': bucketing, + 'ec': ctx, + }): - :param attributes: An optional dictionary of attributes - :type attributes: dict + return self._splitter.get_treatment(bucketing, flag.seed, condition.partitions, flag.algo), condition.label - :return: The treatment for the key and split - :rtype: object - """ - # Fetching Split definition - split = self._split_storage.get(feature) + return flag.default_treatment, Label.NO_CONDITION_MATCHED - # Calling evaluation - evaluation = self._evaluate_treatment(feature, matching_key, - bucketing_key, attributes, split) +class EvaluationDataFactory: - return evaluation + def __init__(self, split_storage, segment_storage, rbs_segment_storage): + self._flag_storage = split_storage + self._segment_storage = segment_storage + self._rbs_segment_storage = rbs_segment_storage - def evaluate_features(self, features, matching_key, bucketing_key, attributes=None): + def context_for(self, key, feature_names): """ - Evaluate the user submitted data against multiple features and return the resulting - treatment. - - :param features: The features for which to get the treatments - :type feature: list(str) - - :param matching_key: The matching_key for which to get the treatment - :type matching_key: str - - :param bucketing_key: The bucketing_key for which to get the treatment + Recursively iterate & fetch all data required to evaluate these flags. + :type features: list :type bucketing_key: str - - :param attributes: An optional dictionary of attributes :type attributes: dict - :return: The treatments for the key and splits - :rtype: object + :rtype: EvaluationContext """ - return { - feature: self._evaluate_treatment(feature, matching_key, - bucketing_key, attributes, split) - for (feature, split) in self._split_storage.fetch_many(features).items() - } - - def _get_treatment_for_split(self, split, matching_key, bucketing_key, attributes=None): + pending = set(feature_names) + pending_rbs = set() + splits = {} + rb_segments = {} + pending_memberships = set() + while pending or pending_rbs: + fetched = self._flag_storage.fetch_many(list(pending)) + fetched_rbs = self._rbs_segment_storage.fetch_many(list(pending_rbs)) + features, rbsegments, splits, rb_segments = update_objects(fetched, fetched_rbs, splits, rb_segments) + pending, pending_memberships, pending_rbs = get_pending_objects(features, splits, rbsegments, rb_segments, pending_memberships) + + return EvaluationContext( + splits, + { segment: self._segment_storage.segment_contains(segment, key) + for segment in pending_memberships + }, + rb_segments + ) + +class AsyncEvaluationDataFactory: + + def __init__(self, split_storage, segment_storage, rbs_segment_storage): + self._flag_storage = split_storage + self._segment_storage = segment_storage + self._rbs_segment_storage = rbs_segment_storage + + async def context_for(self, key, feature_names): """ - Evaluate the feature considering the conditions. - - If there is a match, it will return the condition and the label. - Otherwise, it will return (None, None) - - :param split: The split for which to get the treatment - :type split: Split - - :param matching_key: The key for which to get the treatment - :type key: str - - :param bucketing_key: The key for which to get the treatment - :type key: str - - :param attributes: An optional dictionary of attributes + Recursively iterate & fetch all data required to evaluate these flags. + :type features: list + :type bucketing_key: str :type attributes: dict - :return: The resulting treatment and label - :rtype: tuple + :rtype: EvaluationContext """ - if bucketing_key is None: - bucketing_key = matching_key - - roll_out = False - - context = { - 'segment_storage': self._segment_storage, - 'evaluator': self, - 'bucketing_key': bucketing_key - } - - for condition in split.conditions: - if (not roll_out and - condition.condition_type == ConditionType.ROLLOUT): - if split.traffic_allocation < 100: - bucket = self._splitter.get_bucket( - bucketing_key, - split.traffic_allocation_seed, - split.algo - ) - if bucket > split.traffic_allocation: - return split.default_treatment, Label.NOT_IN_SPLIT - roll_out = True - - condition_matches = condition.matches( - matching_key, - attributes=attributes, - context=context - ) - - if condition_matches: - return self._splitter.get_treatment( - bucketing_key, - split.seed, - condition.partitions, - split.algo - ), condition.label - - # No condition matches - return None, None + pending = set(feature_names) + pending_rbs = set() + splits = {} + rb_segments = {} + pending_memberships = set() + while pending or pending_rbs: + fetched = await self._flag_storage.fetch_many(list(pending)) + fetched_rbs = await self._rbs_segment_storage.fetch_many(list(pending_rbs)) + features, rbsegments, splits, rb_segments = update_objects(fetched, fetched_rbs, splits, rb_segments) + pending, pending_memberships, pending_rbs = get_pending_objects(features, splits, rbsegments, rb_segments, pending_memberships) + + segment_names = list(pending_memberships) + segment_memberships = await asyncio.gather(*[ + self._segment_storage.segment_contains(segment, key) + for segment in segment_names + ]) + + return EvaluationContext( + splits, + dict(zip(segment_names, segment_memberships)), + rb_segments + ) + +def get_dependencies(object): + """ + :rtype: tuple(list, list) + """ + feature_names = [] + segment_names = [] + rbs_segment_names = [] + for condition in object.conditions: + for matcher in condition.matchers: + if isinstance(matcher,RuleBasedSegmentMatcher): + rbs_segment_names.append(matcher._rbs_segment_name) + if isinstance(matcher,UserDefinedSegmentMatcher): + segment_names.append(matcher._segment_name) + elif isinstance(matcher, DependencyMatcher): + feature_names.append(matcher._split_name) + + return feature_names, segment_names, rbs_segment_names + +def filter_missing(features): + return {k: v for (k, v) in features.items() if v is not None} + +def get_pending_objects(features, splits, rbsegments, rb_segments, pending_memberships): + pending = set() + pending_rbs = set() + for feature in features.values(): + cf, cs, crbs = get_dependencies(feature) + cf.extend(get_prerequisites(feature)) + pending.update(filter(lambda f: f not in splits, cf)) + pending_memberships.update(cs) + pending_rbs.update(filter(lambda f: f not in rb_segments, crbs)) + + for rb_segment in rbsegments.values(): + cf, cs, crbs = get_dependencies(rb_segment) + pending.update(filter(lambda f: f not in splits, cf)) + pending_memberships.update(cs) + for excluded_segment in rb_segment.excluded.get_excluded_segments(): + if excluded_segment.type == SegmentType.STANDARD: + pending_memberships.add(excluded_segment.name) + else: + pending_rbs.update(filter(lambda f: f not in rb_segments, [excluded_segment.name])) + pending_rbs.update(filter(lambda f: f not in rb_segments, crbs)) + + return pending, pending_memberships, pending_rbs + +def update_objects(fetched, fetched_rbs, splits, rb_segments): + features = filter_missing(fetched) + rbsegments = filter_missing(fetched_rbs) + splits.update(features) + rb_segments.update(rbsegments) + + return features, rbsegments, splits, rb_segments + +def get_prerequisites(feature): + return [prerequisite.feature_flag_name for prerequisite in feature.prerequisites] diff --git a/splitio/engine/filters.py b/splitio/engine/filters.py new file mode 100644 index 00000000..6c16be8f --- /dev/null +++ b/splitio/engine/filters.py @@ -0,0 +1,85 @@ +import abc +import threading + +from bloom_filter2 import BloomFilter as BloomFilter2 + +class BaseFilter(object, metaclass=abc.ABCMeta): + """Impressions Filter interface.""" + + @abc.abstractmethod + def add(self, data): + """ + Return a boolean flag + + """ + pass + + @abc.abstractmethod + def contains(self, data): + """ + Return a boolean flag + + """ + pass + + @abc.abstractmethod + def clear(self): + """ + No return + + """ + pass + +class BloomFilter(BaseFilter): + """Optimized mode strategy.""" + + def __init__(self, max_elements=5000, error_rate=0.01): + """ + Construct a bloom filter instance. + + :param max_element: maximum elements in the filter + :type string: + + :param error_rate: error rate for the false positives, reduce it will consume more memory + :type numeric: + """ + self._max_elements = max_elements + self._error_rate = error_rate + self._imps_bloom_filter = BloomFilter2(max_elements=self._max_elements, error_rate=self._error_rate) + self._lock = threading.RLock() + + def add(self, data): + """ + Add an item to the bloom filter instance. + + :param data: element to be added + :type string: + + :return: True if successful + :rtype: boolean + """ + with self._lock: + self._imps_bloom_filter.add(data) + return data in self._imps_bloom_filter + + def contains(self, data): + """ + Check if an item exist in the bloom filter instance. + + :param data: element to be checked + :type string: + + :return: True if exist + :rtype: boolean + """ + with self._lock: + return data in self._imps_bloom_filter + + def clear(self): + """ + Destroy the current filter instance and create new one. + + """ + with self._lock: + self._imps_bloom_filter.close() + self._imps_bloom_filter = BloomFilter2(max_elements=self._max_elements, error_rate=self._error_rate) diff --git a/splitio/engine/hashfns/legacy.py b/splitio/engine/hashfns/legacy.py index 1a2dc267..bb461d4f 100644 --- a/splitio/engine/hashfns/legacy.py +++ b/splitio/engine/hashfns/legacy.py @@ -5,6 +5,7 @@ def as_int32(value): """Handle overflow when working with 32 lower bits of 64 bit ints.""" if not -2147483649 <= value <= 2147483648: return (value + 2147483648) % 4294967296 - 2147483648 + return value diff --git a/splitio/engine/impressions/__init__.py b/splitio/engine/impressions/__init__.py new file mode 100644 index 00000000..fdd84211 --- /dev/null +++ b/splitio/engine/impressions/__init__.py @@ -0,0 +1,138 @@ +from splitio.engine.impressions.impressions import ImpressionsMode +from splitio.engine.impressions.strategies import StrategyNoneMode, StrategyDebugMode, StrategyOptimizedMode +from splitio.engine.impressions.adapters import InMemorySenderAdapter, RedisSenderAdapter, PluggableSenderAdapter, RedisSenderAdapterAsync, \ + InMemorySenderAdapterAsync, PluggableSenderAdapterAsync +from splitio.tasks.unique_keys_sync import UniqueKeysSyncTask, ClearFilterSyncTask, UniqueKeysSyncTaskAsync, ClearFilterSyncTaskAsync +from splitio.sync.unique_keys import UniqueKeysSynchronizer, ClearFilterSynchronizer, UniqueKeysSynchronizerAsync, ClearFilterSynchronizerAsync +from splitio.sync.impression import ImpressionsCountSynchronizer, ImpressionsCountSynchronizerAsync +from splitio.tasks.impressions_sync import ImpressionsCountSyncTask, ImpressionsCountSyncTaskAsync + +def set_classes(storage_mode, impressions_mode, api_adapter, imp_counter, unique_keys_tracker, prefix=None): + """ + Createe and return instances based on storage, impressions and threading mode + + :param storage_mode: storage mode (MEMORY, REDIS or PLUGGABLE) + :type storage_mode: str + :param impressions_mode: impressions mode used + :type impressions_mode: splitio.engine.impressions.impressions.ImpressionsMode + :param api_adapter: api adapter instance(s) + :type impressions_mode: dict or splitio.storage.adapters.redis.RedisAdapter/splitio.storage.adapters.redis.RedisAdapterAsync + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter/splitio.engine.impressions.Counter + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker/splitio.engine.unique_keys_tracker.UniqueKeysTrackerAsync + :param prefix: Prefix used for redis or pluggable adapters + :type prefix: str + + :return: tuple of classes instances. + :rtype: (splitio.sync.unique_keys.UniqueKeysSynchronizer, + splitio.sync.unique_keys.ClearFilterSynchronizer, + splitio.tasks.unique_keys_sync.UniqueKeysTask, + splitio.tasks.unique_keys_sync.ClearFilterTask, + splitio.sync.impressions_sync.ImpressionsCountSynchronizer, + splitio.tasks.impressions_sync.ImpressionsCountSyncTask, + splitio.engine.impressions.strategies.StrategyNoneMode/splitio.engine.impressions.strategies.StrategyDebugMode/splitio.engine.impressions.strategies.StrategyOptimizedMode) + """ + unique_keys_synchronizer = None + clear_filter_sync = None + unique_keys_task = None + clear_filter_task = None + impressions_count_sync = None + impressions_count_task = None + sender_adapter = None + if storage_mode == 'PLUGGABLE': + sender_adapter = PluggableSenderAdapter(api_adapter, prefix) + api_telemetry_adapter = sender_adapter + api_impressions_adapter = sender_adapter + elif storage_mode == 'REDIS': + sender_adapter = RedisSenderAdapter(api_adapter) + api_telemetry_adapter = sender_adapter + api_impressions_adapter = sender_adapter + else: + api_telemetry_adapter = api_adapter['telemetry'] + api_impressions_adapter = api_adapter['impressions'] + sender_adapter = InMemorySenderAdapter(api_telemetry_adapter) + + none_strategy = StrategyNoneMode() + unique_keys_synchronizer = UniqueKeysSynchronizer(sender_adapter, unique_keys_tracker) + unique_keys_task = UniqueKeysSyncTask(unique_keys_synchronizer.send_all) + clear_filter_sync = ClearFilterSynchronizer(unique_keys_tracker) + impressions_count_sync = ImpressionsCountSynchronizer(api_impressions_adapter, imp_counter) + impressions_count_task = ImpressionsCountSyncTask(impressions_count_sync.synchronize_counters) + clear_filter_task = ClearFilterSyncTask(clear_filter_sync.clear_all) + unique_keys_tracker.set_queue_full_hook(unique_keys_task.flush) + + if impressions_mode == ImpressionsMode.NONE: + imp_strategy = StrategyNoneMode() + elif impressions_mode == ImpressionsMode.DEBUG: + imp_strategy = StrategyDebugMode() + else: + imp_strategy = StrategyOptimizedMode() + + return unique_keys_synchronizer, clear_filter_sync, unique_keys_task, clear_filter_task, \ + impressions_count_sync, impressions_count_task, imp_strategy, none_strategy + +def set_classes_async(storage_mode, impressions_mode, api_adapter, imp_counter, unique_keys_tracker, prefix=None): + """ + Createe and return instances based on storage, impressions and async mode + + :param storage_mode: storage mode (MEMORY, REDIS or PLUGGABLE) + :type storage_mode: str + :param impressions_mode: impressions mode used + :type impressions_mode: splitio.engine.impressions.impressions.ImpressionsMode + :param api_adapter: api adapter instance(s) + :type impressions_mode: dict or splitio.storage.adapters.redis.RedisAdapter/splitio.storage.adapters.redis.RedisAdapterAsync + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter/splitio.engine.impressions.Counter + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker/splitio.engine.unique_keys_tracker.UniqueKeysTrackerAsync + :param prefix: Prefix used for redis or pluggable adapters + :type prefix: str + + :return: tuple of classes instances. + :rtype: (splitio.sync.unique_keys.UniqueKeysSynchronizerAsync, + splitio.sync.unique_keys.ClearFilterSynchronizerAsync, + splitio.tasks.unique_keys_sync.UniqueKeysTaskAsync, + splitio.tasks.unique_keys_sync.ClearFilterTaskAsync, + splitio.sync.impressions_sync.ImpressionsCountSynchronizerAsync, + splitio.tasks.impressions_sync.ImpressionsCountSyncTaskAsync, + splitio.engine.impressions.strategies.StrategyNoneMode/splitio.engine.impressions.strategies.StrategyDebugMode/splitio.engine.impressions.strategies.StrategyOptimizedMode) + """ + unique_keys_synchronizer = None + clear_filter_sync = None + unique_keys_task = None + clear_filter_task = None + impressions_count_sync = None + impressions_count_task = None + sender_adapter = None + if storage_mode == 'PLUGGABLE': + sender_adapter = PluggableSenderAdapterAsync(api_adapter, prefix) + api_telemetry_adapter = sender_adapter + api_impressions_adapter = sender_adapter + elif storage_mode == 'REDIS': + sender_adapter = RedisSenderAdapterAsync(api_adapter) + api_telemetry_adapter = sender_adapter + api_impressions_adapter = sender_adapter + else: + api_telemetry_adapter = api_adapter['telemetry'] + api_impressions_adapter = api_adapter['impressions'] + sender_adapter = InMemorySenderAdapterAsync(api_telemetry_adapter) + + none_strategy = StrategyNoneMode() + unique_keys_synchronizer = UniqueKeysSynchronizerAsync(sender_adapter, unique_keys_tracker) + unique_keys_task = UniqueKeysSyncTaskAsync(unique_keys_synchronizer.send_all) + clear_filter_sync = ClearFilterSynchronizerAsync(unique_keys_tracker) + impressions_count_sync = ImpressionsCountSynchronizerAsync(api_impressions_adapter, imp_counter) + impressions_count_task = ImpressionsCountSyncTaskAsync(impressions_count_sync.synchronize_counters) + clear_filter_task = ClearFilterSyncTaskAsync(clear_filter_sync.clear_all) + unique_keys_tracker.set_queue_full_hook(unique_keys_task.flush) + + if impressions_mode == ImpressionsMode.NONE: + imp_strategy = StrategyNoneMode() + elif impressions_mode == ImpressionsMode.DEBUG: + imp_strategy = StrategyDebugMode() + else: + imp_strategy = StrategyOptimizedMode() + + return unique_keys_synchronizer, clear_filter_sync, unique_keys_task, clear_filter_task, \ + impressions_count_sync, impressions_count_task, imp_strategy, none_strategy diff --git a/splitio/engine/impressions/adapters.py b/splitio/engine/impressions/adapters.py new file mode 100644 index 00000000..d5e3dcaf --- /dev/null +++ b/splitio/engine/impressions/adapters.py @@ -0,0 +1,397 @@ +import abc +import logging +import json + +from splitio.storage.adapters.redis import RedisAdapterException + +_LOGGER = logging.getLogger(__name__) +_MTK_QUEUE_KEY = 'SPLITIO.uniquekeys' +_MTK_KEY_DEFAULT_TTL = 3600 +_IMP_COUNT_QUEUE_KEY = 'SPLITIO.impressions.count' +_IMP_COUNT_KEY_DEFAULT_TTL = 3600 + +class ImpressionsSenderAdapter(object, metaclass=abc.ABCMeta): + """Impressions Sender Adapter interface.""" + + @abc.abstractmethod + def record_unique_keys(self, data): + """ + No Return value + + """ + pass + +class InMemorySenderAdapterBase(ImpressionsSenderAdapter): + """In Memory Impressions Sender Adapter base class.""" + + def record_unique_keys(self, uniques): + """ + post the unique keys to split back end. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + pass + + def _uniques_formatter(self, uniques): + """ + Format the unique keys dictionary array to a JSON body + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature1_flag': set(), 'feature2_flag': set(), .. } + + :return: unique keys JSON array + :rtype: json + """ + return [{'f': feature, 'ks': list(keys)} for feature, keys in uniques.items()] + +class InMemorySenderAdapter(InMemorySenderAdapterBase): + """In Memory Impressions Sender Adapter class.""" + + def __init__(self, telemtry_http_client): + """ + Initialize In memory sender adapter instance + + :param telemtry_http_client: instance of telemetry http api + :type telemtry_http_client: splitio.api.telemetry.TelemetryAPI + """ + self._telemtry_http_client = telemtry_http_client + + def record_unique_keys(self, uniques): + """ + post the unique keys to split back end. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(uniques) == 0: + return + + self._telemtry_http_client.record_unique_keys({'keys': self._uniques_formatter(uniques)}) + + +class InMemorySenderAdapterAsync(InMemorySenderAdapterBase): + """In Memory Impressions Sender Adapter class.""" + + def __init__(self, telemtry_http_client): + """ + Initialize In memory sender adapter instance + + :param telemtry_http_client: instance of telemetry http api + :type telemtry_http_client: splitio.api.telemetry.TelemetryAPI + """ + self._telemtry_http_client = telemtry_http_client + + async def record_unique_keys(self, uniques): + """ + post the unique keys to split back end. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(uniques) == 0: + return + + await self._telemtry_http_client.record_unique_keys({'keys': self._uniques_formatter(uniques)}) + + +class RedisSenderAdapter(ImpressionsSenderAdapter): + """Redis Impressions Sender Adapter class.""" + + def __init__(self, redis_client): + """ + Initialize Redis sender adapter instance + + :param telemtry_http_client: instance of telemetry http api + :type telemtry_http_client: splitio.api.telemetry.TelemetryAPI + """ + self._redis_client = redis_client + + def record_unique_keys(self, uniques): + """ + post the unique keys to redis. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(uniques) == 0: + return + + bulk_mtks = _uniques_formatter(uniques) + try: + inserted = self._redis_client.rpush(_MTK_QUEUE_KEY, *bulk_mtks) + self._expire_keys(_MTK_QUEUE_KEY, _MTK_KEY_DEFAULT_TTL, inserted, len(bulk_mtks)) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add mtks to redis') + _LOGGER.error('Error: ', exc_info=True) + return False + + def flush_counters(self, to_send): + """ + post the impression counters to redis. + + :param to_send: unique keys disctionary + :type to_send: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(to_send) == 0: + return + + try: + resulted = 0 + counted = 0 + pipe = self._redis_client.pipeline() + for pf_count in to_send: + pipe.hincrby(_IMP_COUNT_QUEUE_KEY, pf_count.feature + "::" + str(pf_count.timeframe), pf_count.count) + counted += pf_count.count + resulted = sum(pipe.execute()) + self._expire_keys(_IMP_COUNT_QUEUE_KEY, + _IMP_COUNT_KEY_DEFAULT_TTL, resulted, counted) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add counters to redis') + _LOGGER.error('Error: ', exc_info=True) + return False + + def _expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._redis_client.expire(queue_key, key_default_ttl) + + +class RedisSenderAdapterAsync(ImpressionsSenderAdapter): + """In Redis Impressions Sender Adapter async class.""" + + def __init__(self, redis_client): + """ + Initialize Redis sender adapter instance + + :param telemtry_http_client: instance of telemetry http api + :type telemtry_http_client: splitio.api.telemetry.TelemetryAPI + """ + self._redis_client = redis_client + + async def record_unique_keys(self, uniques): + """ + post the unique keys to redis. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(uniques) == 0: + return True + + bulk_mtks = _uniques_formatter(uniques) + try: + inserted = await self._redis_client.rpush(_MTK_QUEUE_KEY, *bulk_mtks) + await self._expire_keys(_MTK_QUEUE_KEY, _MTK_KEY_DEFAULT_TTL, inserted, len(bulk_mtks)) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add mtks to redis') + _LOGGER.error('Error: ', exc_info=True) + return False + + async def flush_counters(self, to_send): + """ + post the impression counters to redis. + + :param to_send: unique keys disctionary + :type to_send: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(to_send) == 0: + return True + + try: + resulted = 0 + counted = 0 + pipe = self._redis_client.pipeline() + for pf_count in to_send: + pipe.hincrby(_IMP_COUNT_QUEUE_KEY, pf_count.feature + "::" + str(pf_count.timeframe), pf_count.count) + counted += pf_count.count + resulted = sum(await pipe.execute()) + await self._expire_keys(_IMP_COUNT_QUEUE_KEY, + _IMP_COUNT_KEY_DEFAULT_TTL, resulted, counted) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add counters to redis') + _LOGGER.error('Error: ', exc_info=True) + return False + + async def _expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._redis_client.expire(queue_key, key_default_ttl) + + +class PluggableSenderAdapter(ImpressionsSenderAdapter): + """Pluggable Impressions Sender Adapter class.""" + + def __init__(self, adapter_client, prefix=None): + """ + Initialize pluggable sender adapter instance + + :param telemtry_http_client: instance of telemetry http api + :type telemtry_http_client: splitio.api.telemetry.TelemetryAPI + """ + self._adapter_client = adapter_client + self._prefix = "" + if prefix is not None: + self._prefix = prefix + "." + + def record_unique_keys(self, uniques): + """ + post the unique keys to storage. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(uniques) == 0: + return + + bulk_mtks = _uniques_formatter(uniques) + try: + inserted = self._adapter_client.push_items(self._prefix + _MTK_QUEUE_KEY, *bulk_mtks) + self._expire_keys(self._prefix + _MTK_QUEUE_KEY, _MTK_KEY_DEFAULT_TTL, inserted, len(bulk_mtks)) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add mtks to storage adapter') + _LOGGER.error('Error: ', exc_info=True) + return False + + def flush_counters(self, to_send): + """ + post the impression counters to storage. + + :param to_send: unique keys disctionary + :type to_send: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(to_send) == 0: + return + + try: + resulted = 0 + for pf_count in to_send: + key = self._prefix + _IMP_COUNT_QUEUE_KEY + "." + pf_count.feature + "::" + str(pf_count.timeframe) + resulted = self._adapter_client.increment(key, pf_count.count) + self._expire_keys(key, _IMP_COUNT_KEY_DEFAULT_TTL, resulted, pf_count.count) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add counters to storage adapter') + _LOGGER.error('Error: ', exc_info=True) + return False + + def _expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._adapter_client.expire(queue_key, key_default_ttl) + + +class PluggableSenderAdapterAsync(ImpressionsSenderAdapter): + """Pluggable Impressions Sender Adapter class.""" + + def __init__(self, adapter_client, prefix=None): + """ + Initialize pluggable sender adapter instance + + :param telemtry_http_client: instance of telemetry http api + :type telemtry_http_client: splitio.api.telemetry.TelemetryAPI + """ + self._adapter_client = adapter_client + self._prefix = "" + if prefix is not None: + self._prefix = prefix + "." + + async def record_unique_keys(self, uniques): + """ + post the unique keys to storage. + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(uniques) == 0: + return True + + bulk_mtks = _uniques_formatter(uniques) + try: + inserted = await self._adapter_client.push_items(self._prefix + _MTK_QUEUE_KEY, *bulk_mtks) + await self._expire_keys(self._prefix + _MTK_QUEUE_KEY, _MTK_KEY_DEFAULT_TTL, inserted, len(bulk_mtks)) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add mtks to storage adapter') + _LOGGER.error('Error: ', exc_info=True) + return False + + async def flush_counters(self, to_send): + """ + post the impression counters to storage. + + :param to_send: unique keys disctionary + :type to_send: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + """ + if len(to_send) == 0: + return True + + try: + resulted = 0 + for pf_count in to_send: + key = self._prefix + _IMP_COUNT_QUEUE_KEY + "." + pf_count.feature + "::" + str(pf_count.timeframe) + resulted = await self._adapter_client.increment(key, pf_count.count) + await self._expire_keys(key, _IMP_COUNT_KEY_DEFAULT_TTL, resulted, pf_count.count) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add counters to storage adapter') + _LOGGER.error('Error: ', exc_info=True) + return False + + async def _expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._adapter_client.expire(queue_key, key_default_ttl) + +def _uniques_formatter(uniques): + """ + Format the unique keys dictionary array to a JSON body + + :param uniques: unique keys disctionary + :type uniques: Dictionary {'feature_flag1': set(), 'feature_flag2': set(), .. } + + :return: unique keys JSON array + :rtype: json + """ + return [json.dumps({'f': feature, 'ks': list(keys)}) for feature, keys in uniques.items()] diff --git a/splitio/engine/impressions/impressions.py b/splitio/engine/impressions/impressions.py new file mode 100644 index 00000000..428fdd13 --- /dev/null +++ b/splitio/engine/impressions/impressions.py @@ -0,0 +1,55 @@ +"""Split evaluator module.""" +from enum import Enum + +class ImpressionsMode(Enum): + """Impressions tracking mode.""" + + OPTIMIZED = "OPTIMIZED" + DEBUG = "DEBUG" + NONE = "NONE" + +class Manager(object): # pylint:disable=too-few-public-methods + """Impression manager.""" + + def __init__(self, strategy, none_strategy, telemetry_runtime_producer): + """ + Construct a manger to track and forward impressions to the queue. + + :param listener: Optional impressions listener that will capture all seen impressions. + :type listener: splitio.client.listener.ImpressionListenerWrapper + + :param strategy: Impressions stragetgy instance + :type strategy: (BaseStrategy) + """ + + self._strategy = strategy + self._none_strategy = none_strategy + self._telemetry_runtime_producer = telemetry_runtime_producer + + def process_impressions(self, impressions_decorated): + """ + Process impressions. + + Impressions are analyzed to see if they've been seen before and counted. + + :param impressions_decorated: List of impression objects with attributes + :type impressions_decorated: list[tuple[splitio.models.impression.ImpressionDecorated, dict]] + + :return: processed and deduped impressions. + :rtype: tuple(list[tuple[splitio.models.impression.Impression, dict]], list(int)) + """ + for_listener_all = [] + for_log_all = [] + for_counter_all = [] + for_unique_keys_tracker_all = [] + for impression_decorated, att in impressions_decorated: + if impression_decorated.disabled: + for_log, for_listener, for_counter, for_unique_keys_tracker = self._none_strategy.process_impressions([(impression_decorated.Impression, att)]) + else: + for_log, for_listener, for_counter, for_unique_keys_tracker = self._strategy.process_impressions([(impression_decorated.Impression, att)]) + for_listener_all.extend(for_listener) + for_log_all.extend(for_log) + for_counter_all.extend(for_counter) + for_unique_keys_tracker_all.extend(for_unique_keys_tracker) + + return for_log_all, len(impressions_decorated) - len(for_log_all), for_listener_all, for_counter_all, for_unique_keys_tracker_all diff --git a/splitio/engine/impressions.py b/splitio/engine/impressions/manager.py similarity index 60% rename from splitio/engine/impressions.py rename to splitio/engine/impressions/manager.py index c8720b5d..56727fd0 100644 --- a/splitio/engine/impressions.py +++ b/splitio/engine/impressions/manager.py @@ -1,25 +1,13 @@ -"""Split evaluator module.""" import threading from collections import defaultdict, namedtuple -from enum import Enum +from splitio.util.time import utctime_ms from splitio.models.impressions import Impression from splitio.engine.hashfns import murmur_128 from splitio.engine.cache.lru import SimpleLruCache -from splitio.client.listener import ImpressionListenerException -from splitio import util - +from splitio.optional.loaders import asyncio _TIME_INTERVAL_MS = 3600 * 1000 # one hour -_IMPRESSION_OBSERVER_CACHE_SIZE = 500000 - - -class ImpressionsMode(Enum): - """Impressions tracking mode.""" - - OPTIMIZED = "OPTIMIZED" - DEBUG = "DEBUG" - def truncate_time(timestamp_ms): """ @@ -33,6 +21,22 @@ def truncate_time(timestamp_ms): """ return timestamp_ms - (timestamp_ms % _TIME_INTERVAL_MS) +def truncate_impressions_time(imps, counter = None): + """ + Process impressions. + + Impressions are truncated based on time + + :param impressions: List of impression objects with attributes + :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + + :returns: truncated list of impressions + :rtype: list[splitio.models.impression.Impression] + """ + this_hour = truncate_time(utctime_ms()) + return [imp for imp, _ in imps] if counter is None \ + else [i for i, _ in imps if i.previous_time is None or i.previous_time < this_hour] + class Hasher(object): # pylint:disable=too-few-public-methods """Impression hasher.""" @@ -149,70 +153,3 @@ def pop_all(self): return [Counter.CountPerFeature(k.feature, k.timeframe, v) for (k, v) in old.items()] - - -class Manager(object): # pylint:disable=too-few-public-methods - """Impression manager.""" - - def __init__(self, mode=ImpressionsMode.OPTIMIZED, standalone=True, listener=None): - """ - Construct a manger to track and forward impressions to the queue. - - :param mode: Impressions capturing mode. - :type mode: ImpressionsMode - - :param standalone: whether the SDK is running in standalone sending impressions by itself - :type standalone: bool - - :param listener: Optional impressions listener that will capture all seen impressions. - :type listener: splitio.client.listener.ImpressionListenerWrapper - """ - self._observer = Observer(_IMPRESSION_OBSERVER_CACHE_SIZE) if standalone else None - self._counter = Counter() if standalone and mode == ImpressionsMode.OPTIMIZED else None - self._listener = listener - - def process_impressions(self, impressions): - """ - Process impressions. - - Impressions are analyzed to see if they've been seen before and counted. - - :param impressions: List of impression objects with attributes - :type impressions: list[tuple[splitio.models.impression.Impression, dict]] - """ - imps = [(self._observer.test_and_set(imp), attrs) for imp, attrs in impressions] \ - if self._observer else impressions - - if self._counter: - self._counter.track([imp for imp, _ in imps]) - - self._send_impressions_to_listener(imps) - - this_hour = truncate_time(util.utctime_ms()) - return [imp for imp, _ in imps] if self._counter is None \ - else [i for i, _ in imps if i.previous_time is None or i.previous_time < this_hour] - - def get_counts(self): - """ - Return counts of impressions per features. - - :returns: A list of counter objects. - :rtype: list[Counter.CountPerFeature] - """ - return self._counter.pop_all() if self._counter is not None else [] - - def _send_impressions_to_listener(self, impressions): - """ - Send impression result to custom listener. - - :param impressions: List of impression objects with attributes - :type impressions: list[tuple[splitio.models.impression.Impression, dict]] - """ - if self._listener is not None: - try: - for impression, attributes in impressions: - self._listener.log_impression(impression, attributes) - except ImpressionListenerException: - pass -# self._logger.error('An exception was raised while calling user-custom impression listener') -# self._logger.debug('Error', exc_info=True) diff --git a/splitio/engine/impressions/strategies.py b/splitio/engine/impressions/strategies.py new file mode 100644 index 00000000..c2b0c565 --- /dev/null +++ b/splitio/engine/impressions/strategies.py @@ -0,0 +1,105 @@ +import abc + +from splitio.engine.impressions.manager import Observer, truncate_time +from splitio.util.time import utctime_ms + +_IMPRESSION_OBSERVER_CACHE_SIZE = 500000 + +class BaseStrategy(object, metaclass=abc.ABCMeta): + """Strategy interface.""" + + @abc.abstractmethod + def process_impressions(self): + """ + Return a list(impressions) object + + """ + pass + +class StrategyDebugMode(BaseStrategy): + """Debug mode strategy.""" + + def __init__(self): + """ + Construct a strategy instance for debug mode. + + """ + self._observer = Observer(_IMPRESSION_OBSERVER_CACHE_SIZE) + + def process_impressions(self, impressions): + """ + Process impressions. + + Impressions are analyzed to see if they've been seen before. + + :param impressions: List of impression objects with attributes + :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + + :returns: Tuple of to be stored, observed and counted impressions, and unique keys tuple + :rtype: list[tuple[splitio.models.impression.Impression, dict]], list[], list[], list[] + """ + imps = [] + for imp, attrs in impressions: + if imp.properties is not None: + imps.append((imp, attrs)) + continue + + imps.append((self._observer.test_and_set(imp), attrs)) + + return [i for i, _ in imps], imps, [], [] + +class StrategyNoneMode(BaseStrategy): + """Debug mode strategy.""" + + def process_impressions(self, impressions): + """ + Process impressions. + + Impressions are analyzed to see if they've been seen before and counted. + Unique keys tracking are updated. + + :param impressions: List of impression objects with attributes + :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + + :returns: Tuple of to be stored, observed and counted impressions, and unique keys tuple + :rtype: list[[], dict]], list[splitio.models.impression.Impression], list[splitio.models.impression.Impression], list[(str, str)] + """ + counter_imps = [imp for imp, _ in impressions] + unique_keys_tracker = [] + for i, _ in impressions: + unique_keys_tracker.append((i.matching_key, i.feature_name)) + return [], impressions, counter_imps, unique_keys_tracker + +class StrategyOptimizedMode(BaseStrategy): + """Optimized mode strategy.""" + + def __init__(self): + """ + Construct a strategy instance for optimized mode. + + """ + self._observer = Observer(_IMPRESSION_OBSERVER_CACHE_SIZE) + + def process_impressions(self, impressions): + """ + Process impressions. + + Impressions are analyzed to see if they've been seen before and counted. + + :param impressions: List of impression objects with attributes + :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + + :returns: Tuple of to be stored, observed and counted impressions, and unique keys tuple + :rtype: list[tuple[splitio.models.impression.Impression, dict]], list[splitio.models.impression.Impression], list[splitio.models.impression.Impression], list[] + """ + imps = [] + for imp, attrs in impressions: + if imp.properties is not None: + imps.append((imp, attrs)) + continue + + imps.append((self._observer.test_and_set(imp), attrs)) + + counter_imps = [imp for imp, _ in imps if imp.previous_time != None] + this_hour = truncate_time(utctime_ms()) + return [i for i, _ in imps if i.previous_time is None or i.previous_time < this_hour], imps, counter_imps, [] diff --git a/splitio/engine/impressions/unique_keys_tracker.py b/splitio/engine/impressions/unique_keys_tracker.py new file mode 100644 index 00000000..4e8da012 --- /dev/null +++ b/splitio/engine/impressions/unique_keys_tracker.py @@ -0,0 +1,168 @@ +import abc +import threading +import logging + +from splitio.engine.filters import BloomFilter +from splitio.optional.loaders import asyncio + +_LOGGER = logging.getLogger(__name__) + +class UniqueKeysTrackerBase(object, metaclass=abc.ABCMeta): + """Unique Keys Tracker base class.""" + + @abc.abstractmethod + def track(self, key, feature_flag_name): + """ + Return a boolean flag + """ + pass + + def set_queue_full_hook(self, hook): + """ + Set a hook to be called when the queue is full. + + :param h: Hook to be called when the queue is full + """ + if callable(hook): + self._queue_full_hook = hook + + def _add_or_update(self, feature_flag_name, key): + """ + Add the feature_name+key to both bloom filter and dictionary. + + :param feature_flag_name: feature flag name associated with the key + :type feature_flag_name: str + :param key: key to be added to MTK list + :type key: int + """ + if feature_flag_name not in self._cache: + self._cache[feature_flag_name] = set() + self._cache[feature_flag_name].add(key) + + +class UniqueKeysTracker(UniqueKeysTrackerBase): + """Unique Keys Tracker class.""" + + def __init__(self, cache_size=30000): + """ + Initialize unique keys tracker instance + + :param cache_size: The size of the unique keys dictionary + :type key: int + """ + self._cache_size = cache_size + self._filter = BloomFilter(cache_size) + self._lock = threading.RLock() + self._cache = {} + self._queue_full_hook = None + self._current_cache_size = 0 + + def track(self, key, feature_flag_name): + """ + Return a boolean flag + + :param key: key to be added to MTK list + :type key: int + :param feature_flag_name: feature flag name associated with the key + :type feature_flag_name: str + + :return: True if successful + :rtype: boolean + """ + with self._lock: + if self._filter.contains(feature_flag_name+key): + return False + + self._add_or_update(feature_flag_name, key) + self._filter.add(feature_flag_name+key) + self._current_cache_size += 1 + + if self._current_cache_size > self._cache_size: + _LOGGER.info( + 'Unique Keys queue is full, flushing the current queue now.' + ) + if self._queue_full_hook is not None and callable(self._queue_full_hook): + _LOGGER.info('Calling hook.') + self._queue_full_hook() + return True + + def clear_filter(self): + """ + Delete the filter items + + """ + with self._lock: + self._filter.clear() + + def get_cache_info_and_pop_all(self): + with self._lock: + temp_cach = self._cache + temp_cache_size = self._current_cache_size + self._cache = {} + self._current_cache_size = 0 + + return temp_cach, temp_cache_size + + +class UniqueKeysTrackerAsync(UniqueKeysTrackerBase): + """Unique Keys Tracker async class.""" + + def __init__(self, cache_size=30000): + """ + Initialize unique keys tracker instance + + :param cache_size: The size of the unique keys dictionary + :type key: int + """ + self._cache_size = cache_size + self._filter = BloomFilter(cache_size) + self._lock = asyncio.Lock() + self._cache = {} + self._queue_full_hook = None + self._current_cache_size = 0 + + async def track(self, key, feature_flag_name): + """ + Return a boolean flag + + :param key: key to be added to MTK list + :type key: int + :param feature_flag_name: feature flag name associated with the key + :type feature_flag_name: str + + :return: True if successful + :rtype: boolean + """ + async with self._lock: + if self._filter.contains(feature_flag_name+key): + return False + + self._add_or_update(feature_flag_name, key) + self._filter.add(feature_flag_name+key) + self._current_cache_size += 1 + + if self._current_cache_size > self._cache_size: + _LOGGER.info( + 'Unique Keys queue is full, flushing the current queue now.' + ) + if self._queue_full_hook is not None and callable(self._queue_full_hook): + _LOGGER.info('Calling hook.') + await self._queue_full_hook() + return True + + async def clear_filter(self): + """ + Delete the filter items + + """ + async with self._lock: + self._filter.clear() + + async def get_cache_info_and_pop_all(self): + async with self._lock: + temp_cach = self._cache + temp_cache_size = self._current_cache_size + self._cache = {} + self._current_cache_size = 0 + + return temp_cach, temp_cache_size \ No newline at end of file diff --git a/splitio/engine/telemetry.py b/splitio/engine/telemetry.py new file mode 100644 index 00000000..f3bbba53 --- /dev/null +++ b/splitio/engine/telemetry.py @@ -0,0 +1,687 @@ +"""Telemetry engine classes.""" +import json +import os + +import logging +_LOGGER = logging.getLogger(__name__) + +from splitio.models.telemetry import CounterConstants, UpdateFromSSE + +class TelemetryStorageProducerBase(object): + """Telemetry storage producer base class.""" + + def get_telemetry_init_producer(self): + """get init producer instance.""" + return self._telemetry_init_producer + + def get_telemetry_evaluation_producer(self): + """get evaluation producer instance.""" + return self._telemetry_evaluation_producer + + def get_telemetry_runtime_producer(self): + """get runtime producer instance.""" + return self._telemetry_runtime_producer + + +class TelemetryStorageProducer(TelemetryStorageProducerBase): + """Telemetry storage producer class.""" + + def __init__(self, telemetry_storage): + """Initialize all producer classes.""" + self._telemetry_init_producer = TelemetryInitProducer(telemetry_storage) + self._telemetry_evaluation_producer = TelemetryEvaluationProducer(telemetry_storage) + self._telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + +class TelemetryStorageProducerAsync(TelemetryStorageProducerBase): + """Telemetry storage producer class.""" + + def __init__(self, telemetry_storage): + """Initialize all producer classes.""" + self._telemetry_init_producer = TelemetryInitProducerAsync(telemetry_storage) + self._telemetry_evaluation_producer = TelemetryEvaluationProducerAsync(telemetry_storage) + self._telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + +class TelemetryInitProducerBase(object): + """Telemetry init producer base class.""" + + def _get_app_worker_id(self): + try: + import uwsgi + return "uwsgi", str(uwsgi.worker_id()) + + except ModuleNotFoundError: + _LOGGER.debug("NO uwsgi") + pass + + if 'gunicorn' in os.environ.get("SERVER_SOFTWARE", ""): + return "gunicorn", str(os.getpid()) + + else: + return None, None + + +class TelemetryInitProducer(TelemetryInitProducerBase): + """Telemetry init producer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + def record_config(self, config, extra_config, total_flag_sets=0, invalid_flag_sets=0): + """Record configurations.""" + self._telemetry_storage.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + current_app, app_worker_id = self._get_app_worker_id() + if current_app is not None: + self.add_config_tag("initilization:" + current_app) + self.add_config_tag("worker:#" + app_worker_id) + + def record_ready_time(self, ready_time): + """Record ready time.""" + self._telemetry_storage.record_ready_time(ready_time) + + def record_flag_sets(self, flag_sets): + """Record flag sets.""" + self._telemetry_storage.record_flag_sets(flag_sets) + + def record_invalid_flag_sets(self, flag_sets): + """Record invalid flag sets.""" + self._telemetry_storage.record_invalid_flag_sets(flag_sets) + + def record_bur_time_out(self): + """Record block until ready timeout.""" + self._telemetry_storage.record_bur_time_out() + + def record_not_ready_usage(self): + """record non-ready usage.""" + self._telemetry_storage.record_not_ready_usage() + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + self._telemetry_storage.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def add_config_tag(self, tag): + """Record tag string.""" + self._telemetry_storage.add_config_tag(tag) + + +class TelemetryInitProducerAsync(TelemetryInitProducerBase): + """Telemetry init producer async class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + async def record_config(self, config, extra_config, total_flag_sets=0, invalid_flag_sets=0): + """Record configurations.""" + await self._telemetry_storage.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + current_app, app_worker_id = self._get_app_worker_id() + if current_app is not None: + await self.add_config_tag("initilization:" + current_app) + await self.add_config_tag("worker:#" + app_worker_id) + + async def record_ready_time(self, ready_time): + """Record ready time.""" + await self._telemetry_storage.record_ready_time(ready_time) + + async def record_flag_sets(self, flag_sets): + """Record flag sets.""" + await self._telemetry_storage.record_flag_sets(flag_sets) + + async def record_invalid_flag_sets(self, flag_sets): + """Record invalid flag sets.""" + await self._telemetry_storage.record_invalid_flag_sets(flag_sets) + + async def record_bur_time_out(self): + """Record block until ready timeout.""" + await self._telemetry_storage.record_bur_time_out() + + async def record_not_ready_usage(self): + """record non-ready usage.""" + await self._telemetry_storage.record_not_ready_usage() + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + await self._telemetry_storage.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + async def add_config_tag(self, tag): + """Record tag string.""" + await self._telemetry_storage.add_config_tag(tag) + + +class TelemetryEvaluationProducer(object): + """Telemetry evaluation producer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + def record_latency(self, method, latency): + """Record method latency time.""" + self._telemetry_storage.record_latency(method, latency) + + def record_exception(self, method): + """Record method exception time.""" + self._telemetry_storage.record_exception(method) + + +class TelemetryEvaluationProducerAsync(object): + """Telemetry evaluation producer async class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + async def record_latency(self, method, latency): + """Record method latency time.""" + await self._telemetry_storage.record_latency(method, latency) + + async def record_exception(self, method): + """Record method exception time.""" + await self._telemetry_storage.record_exception(method) + + +class TelemetryRuntimeProducer(object): + """Telemetry runtime producer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + def add_tag(self, tag): + """Record tag string.""" + self._telemetry_storage.add_tag(tag) + + def record_impression_stats(self, data_type, count): + """Record impressions stats.""" + self._telemetry_storage.record_impression_stats(data_type, count) + + def record_event_stats(self, data_type, count): + """Record events stats.""" + self._telemetry_storage.record_event_stats(data_type, count) + + def record_successful_sync(self, resource, time): + """Record successful sync.""" + self._telemetry_storage.record_successful_sync(resource, time) + + def record_sync_error(self, resource, status): + """Record sync error.""" + self._telemetry_storage.record_sync_error(resource, status) + + def record_sync_latency(self, resource, latency): + """Record latency time.""" + self._telemetry_storage.record_sync_latency(resource, latency) + + def record_auth_rejections(self): + """Record auth rejection.""" + self._telemetry_storage.record_auth_rejections() + + def record_token_refreshes(self): + """Record sse token refresh.""" + self._telemetry_storage.record_token_refreshes() + + def record_streaming_event(self, streaming_event): + """Record incoming streaming event.""" + self._telemetry_storage.record_streaming_event(streaming_event) + + def record_session_length(self, session): + """Record session length.""" + self._telemetry_storage.record_session_length(session) + + def record_update_from_sse(self, event): + """Record update from sse.""" + self._telemetry_storage.record_update_from_sse(event) + +class TelemetryRuntimeProducerAsync(object): + """Telemetry runtime producer async class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + async def add_tag(self, tag): + """Record tag string.""" + await self._telemetry_storage.add_tag(tag) + + async def record_impression_stats(self, data_type, count): + """Record impressions stats.""" + await self._telemetry_storage.record_impression_stats(data_type, count) + + async def record_event_stats(self, data_type, count): + """Record events stats.""" + await self._telemetry_storage.record_event_stats(data_type, count) + + async def record_successful_sync(self, resource, time): + """Record successful sync.""" + await self._telemetry_storage.record_successful_sync(resource, time) + + async def record_sync_error(self, resource, status): + """Record sync error.""" + await self._telemetry_storage.record_sync_error(resource, status) + + async def record_sync_latency(self, resource, latency): + """Record latency time.""" + await self._telemetry_storage.record_sync_latency(resource, latency) + + async def record_auth_rejections(self): + """Record auth rejection.""" + await self._telemetry_storage.record_auth_rejections() + + async def record_token_refreshes(self): + """Record sse token refresh.""" + await self._telemetry_storage.record_token_refreshes() + + async def record_streaming_event(self, streaming_event): + """Record incoming streaming event.""" + await self._telemetry_storage.record_streaming_event(streaming_event) + + async def record_session_length(self, session): + """Record session length.""" + await self._telemetry_storage.record_session_length(session) + + async def record_update_from_sse(self, event): + """Record update from sse.""" + await self._telemetry_storage.record_update_from_sse(event) + +class TelemetryStorageConsumerBase(object): + """Telemetry storage consumer base class.""" + + def get_telemetry_init_consumer(self): + """Get telemetry init instance""" + return self._telemetry_init_consumer + + def get_telemetry_evaluation_consumer(self): + """Get telemetry evaluation instance""" + return self._telemetry_evaluation_consumer + + def get_telemetry_runtime_consumer(self): + """Get telemetry runtime instance""" + return self._telemetry_runtime_consumer + + +class TelemetryStorageConsumer(TelemetryStorageConsumerBase): + """Telemetry storage consumer class.""" + + def __init__(self, telemetry_storage): + """Initialize all consumer classes.""" + self._telemetry_init_consumer = TelemetryInitConsumer(telemetry_storage) + self._telemetry_evaluation_consumer = TelemetryEvaluationConsumer(telemetry_storage) + self._telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + + +class TelemetryStorageConsumerAsync(TelemetryStorageConsumerBase): + """Telemetry storage consumer async class.""" + + def __init__(self, telemetry_storage): + """Initialize all consumer classes.""" + self._telemetry_init_consumer = TelemetryInitConsumerAsync(telemetry_storage) + self._telemetry_evaluation_consumer = TelemetryEvaluationConsumerAsync(telemetry_storage) + self._telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + + +class TelemetryInitConsumer(object): + """Telemetry init consumer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + def get_bur_time_outs(self): + """Get block until ready timeout.""" + return self._telemetry_storage.get_bur_time_outs() + + def get_not_ready_usage(self): + """Get none-ready usage.""" + return self._telemetry_storage.get_not_ready_usage() + + def get_config_stats(self): + """Get config stats.""" + config_stats = self._telemetry_storage.get_config_stats() + config_stats.update({'t': self.pop_config_tags()}) + return config_stats + + def get_config_stats_to_json(self): + """Get config stats in json.""" + return json.dumps(self._telemetry_storage.get_config_stats()) + + def pop_config_tags(self): + """Get and reset tags.""" + return self._telemetry_storage.pop_config_tags() + + +class TelemetryInitConsumerAsync(object): + """Telemetry init consumer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + async def get_bur_time_outs(self): + """Get block until ready timeout.""" + return await self._telemetry_storage.get_bur_time_outs() + + async def get_not_ready_usage(self): + """Get none-ready usage.""" + return await self._telemetry_storage.get_not_ready_usage() + + async def get_config_stats(self): + """Get config stats.""" + config_stats = await self._telemetry_storage.get_config_stats() + config_stats.update({'t': await self.pop_config_tags()}) + return config_stats + + async def get_config_stats_to_json(self): + """Get config stats in json.""" + return json.dumps(await self._telemetry_storage.get_config_stats()) + + async def pop_config_tags(self): + """Get and reset tags.""" + return await self._telemetry_storage.pop_config_tags() + + +class TelemetryEvaluationConsumerBase(object): + """Telemetry evaluation consumer base class.""" + + def _to_json(self, exceptions, latencies): + """Return json formatted stats""" + return { + 'mE': {'t': exceptions['treatment'], + 'ts': exceptions['treatments'], + 'tc': exceptions['treatment_with_config'], + 'tcs': exceptions['treatments_with_config'], + 'tf': exceptions['treatments_by_flag_set'], + 'tfs': exceptions['treatments_by_flag_sets'], + 'tcf': exceptions['treatments_with_config_by_flag_set'], + 'tcfs': exceptions['treatments_with_config_by_flag_sets'], + 'tr': exceptions['track'] + }, + 'mL': {'t': latencies['treatment'], + 'ts': latencies['treatments'], + 'tc': latencies['treatment_with_config'], + 'tcs': latencies['treatments_with_config'], + 'tf': latencies['treatments_by_flag_set'], + 'tfs': latencies['treatments_by_flag_sets'], + 'tcf': latencies['treatments_with_config_by_flag_set'], + 'tcfs': latencies['treatments_with_config_by_flag_sets'], + 'tr': latencies['track'] + }, + } + + +class TelemetryEvaluationConsumer(TelemetryEvaluationConsumerBase): + """Telemetry evaluation consumer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + def pop_exceptions(self): + """Get and reset method exceptions.""" + return self._telemetry_storage.pop_exceptions() + + def pop_latencies(self): + """Get and reset eval latencies.""" + return self._telemetry_storage.pop_latencies() + + def pop_formatted_stats(self): + """ + Get formatted and reset stats. + + :returns: formatted stats + :rtype: Dict + """ + exceptions = self.pop_exceptions()['methodExceptions'] + latencies = self.pop_latencies()['methodLatencies'] + return self._to_json(exceptions, latencies) + + +class TelemetryEvaluationConsumerAsync(TelemetryEvaluationConsumerBase): + """Telemetry evaluation consumer async class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + async def pop_exceptions(self): + """Get and reset method exceptions.""" + return await self._telemetry_storage.pop_exceptions() + + async def pop_latencies(self): + """Get and reset eval latencies.""" + return await self._telemetry_storage.pop_latencies() + + async def pop_formatted_stats(self): + """ + Get formatted and reset stats. + + :returns: formatted stats + :rtype: Dict + """ + exceptions = await self.pop_exceptions() + latencies = await self.pop_latencies() + return self._to_json(exceptions['methodExceptions'], latencies['methodLatencies']) + + +class TelemetryRuntimeConsumerBase(object): + """Telemetry runtime consumer base class.""" + + def _last_synchronization_to_json(self, last_synchronization): + """ + Get formatted last synchronization. + + :returns: formatted stats + :rtype: Dict + """ + return {'sp': last_synchronization['split'], + 'se': last_synchronization['segment'], + 'im': last_synchronization['impression'], + 'ic': last_synchronization['impressionCount'], + 'ev': last_synchronization['event'], + 'te': last_synchronization['telemetry'], + 'to': last_synchronization['token'] + } + + def _http_errors_to_json(self, http_errors): + """ + Get formatted http errors + + :returns: formatted stats + :rtype: Dict + """ + return {'sp': http_errors['split'], + 'se': http_errors['segment'], + 'im': http_errors['impression'], + 'ic': http_errors['impressionCount'], + 'ev': http_errors['event'], + 'te': http_errors['telemetry'], + 'to': http_errors['token'] + } + + def _http_latencies_to_json(self, http_latencies): + """ + Get formatted http latencies + + :returns: formatted stats + :rtype: Dict + """ + return {'sp': http_latencies['split'], + 'se': http_latencies['segment'], + 'im': http_latencies['impression'], + 'ic': http_latencies['impressionCount'], + 'ev': http_latencies['event'], + 'te': http_latencies['telemetry'], + 'to': http_latencies['token'] + } + + def _streaming_events_to_json(self, streaming_events): + """ + Get formatted http latencies + + :returns: formatted stats + :rtype: Dict + """ + return [{'e': event['e'], + 'd': event['d'], + 't': event['t'] + } for event in streaming_events['streamingEvents']] + + +class TelemetryRuntimeConsumer(TelemetryRuntimeConsumerBase): + """Telemetry runtime consumer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + def get_impressions_stats(self, type): + """Get impressions stats""" + return self._telemetry_storage.get_impressions_stats(type) + + def get_events_stats(self, type): + """Get events stats""" + return self._telemetry_storage.get_events_stats(type) + + def get_last_synchronization(self): + """Get last sync""" + return self._telemetry_storage.get_last_synchronization()['lastSynchronizations'] + + def pop_tags(self): + """Get and reset tags.""" + return self._telemetry_storage.pop_tags() + + def pop_http_errors(self): + """Get and reset http errors.""" + return self._telemetry_storage.pop_http_errors() + + def pop_http_latencies(self): + """Get and reset http latencies.""" + return self._telemetry_storage.pop_http_latencies() + + def pop_auth_rejections(self): + """Get and reset auth rejections.""" + return self._telemetry_storage.pop_auth_rejections() + + def pop_token_refreshes(self): + """Get and reset token refreshes.""" + return self._telemetry_storage.pop_token_refreshes() + + def pop_streaming_events(self): + """Get and reset streaming events.""" + return self._telemetry_storage.pop_streaming_events() + + def pop_update_from_sse(self, event): + """Get and reset update from sse.""" + return self._telemetry_storage.pop_update_from_sse(event) + + def get_session_length(self): + """Get session length""" + return self._telemetry_storage.get_session_length() + + def pop_formatted_stats(self): + """ + Get formatted and reset stats. + + :returns: formatted stats + :rtype: Dict + """ + last_synchronization = self.get_last_synchronization() + http_errors = self.pop_http_errors()['httpErrors'] + http_latencies = self.pop_http_latencies()['httpLatencies'] + + return { + 'iQ': self.get_impressions_stats(CounterConstants.IMPRESSIONS_QUEUED), + 'iDe': self.get_impressions_stats(CounterConstants.IMPRESSIONS_DEDUPED), + 'iDr': self.get_impressions_stats(CounterConstants.IMPRESSIONS_DROPPED), + 'eQ': self.get_events_stats(CounterConstants.EVENTS_QUEUED), + 'eD': self.get_events_stats(CounterConstants.EVENTS_DROPPED), + 'lS': self._last_synchronization_to_json(last_synchronization), + 'ufs': {event.value: self.pop_update_from_sse(event) for event in UpdateFromSSE}, + 't': self.pop_tags(), + 'hE': self._http_errors_to_json(http_errors), + 'hL': self._http_latencies_to_json(http_latencies), + 'aR': self.pop_auth_rejections(), + 'tR': self.pop_token_refreshes(), + 'sE': self._streaming_events_to_json(self.pop_streaming_events()), + 'sL': self.get_session_length() + } + + +class TelemetryRuntimeConsumerAsync(TelemetryRuntimeConsumerBase): + """Telemetry runtime consumer class.""" + + def __init__(self, telemetry_storage): + """Constructor.""" + self._telemetry_storage = telemetry_storage + + async def get_impressions_stats(self, type): + """Get impressions stats""" + return await self._telemetry_storage.get_impressions_stats(type) + + async def get_events_stats(self, type): + """Get events stats""" + return await self._telemetry_storage.get_events_stats(type) + + async def get_last_synchronization(self): + """Get last sync""" + last_sync = await self._telemetry_storage.get_last_synchronization() + return last_sync['lastSynchronizations'] + + async def pop_tags(self): + """Get and reset tags.""" + return await self._telemetry_storage.pop_tags() + + async def pop_http_errors(self): + """Get and reset http errors.""" + return await self._telemetry_storage.pop_http_errors() + + async def pop_http_latencies(self): + """Get and reset http latencies.""" + return await self._telemetry_storage.pop_http_latencies() + + async def pop_auth_rejections(self): + """Get and reset auth rejections.""" + return await self._telemetry_storage.pop_auth_rejections() + + async def pop_token_refreshes(self): + """Get and reset token refreshes.""" + return await self._telemetry_storage.pop_token_refreshes() + + async def pop_streaming_events(self): + """Get and reset streaming events.""" + return await self._telemetry_storage.pop_streaming_events() + + async def pop_update_from_sse(self, event): + """Get and reset update from sse.""" + return await self._telemetry_storage.pop_update_from_sse(event) + + async def get_session_length(self): + """Get session length""" + return await self._telemetry_storage.get_session_length() + + async def pop_formatted_stats(self): + """ + Get formatted and reset stats. + + :returns: formatted stats + :rtype: Dict + """ + last_synchronization = await self.get_last_synchronization() + http_errors = await self.pop_http_errors() + http_latencies = await self.pop_http_latencies() + # TODO: if ufs value is too large, use gather to fetch events instead of serial style. + return { + 'iQ': await self.get_impressions_stats(CounterConstants.IMPRESSIONS_QUEUED), + 'iDe': await self.get_impressions_stats(CounterConstants.IMPRESSIONS_DEDUPED), + 'iDr': await self.get_impressions_stats(CounterConstants.IMPRESSIONS_DROPPED), + 'eQ': await self.get_events_stats(CounterConstants.EVENTS_QUEUED), + 'eD': await self.get_events_stats(CounterConstants.EVENTS_DROPPED), + 'ufs': {event.value: await self.pop_update_from_sse(event) for event in UpdateFromSSE}, + 'lS': self._last_synchronization_to_json(last_synchronization), + 't': await self.pop_tags(), + 'hE': self._http_errors_to_json(http_errors['httpErrors']), + 'hL': self._http_latencies_to_json(http_latencies['httpLatencies']), + 'aR': await self.pop_auth_rejections(), + 'tR': await self.pop_token_refreshes(), + 'sE': self._streaming_events_to_json(await self.pop_streaming_events()), + 'sL': await self.get_session_length() + } diff --git a/splitio/events/__init__.py b/splitio/events/__init__.py new file mode 100644 index 00000000..cee5543e --- /dev/null +++ b/splitio/events/__init__.py @@ -0,0 +1,25 @@ +"""Base storage interfaces.""" +import abc + +class EventsManagerInterface(object, metaclass=abc.ABCMeta): + """Events manager interface implemented as an abstract class.""" + + @abc.abstractmethod + def register(self, sdk_event, event_handler): + pass + + @abc.abstractmethod + def unregister(self, sdk_event): + pass + + @abc.abstractmethod + def notify_internal_event(self, sdk_internal_event, event_metadata): + pass + + +class EventsDeliveryInterface(object, metaclass=abc.ABCMeta): + """Events Delivery interface.""" + + @abc.abstractmethod + def deliver(self, sdk_event, event_metadata, event_handler): + pass \ No newline at end of file diff --git a/splitio/events/events_delivery.py b/splitio/events/events_delivery.py new file mode 100644 index 00000000..a582d8a0 --- /dev/null +++ b/splitio/events/events_delivery.py @@ -0,0 +1,28 @@ +"""Events Manager.""" +import logging + +from splitio.events import EventsDeliveryInterface + +_LOGGER = logging.getLogger(__name__) + +class EventsDelivery(EventsDeliveryInterface): + """Events Manager class.""" + + def __init__(self): + """ + Construct Events Manager instance. + """ + + def deliver(self, sdk_event, event_metadata, event_handler): + try: + event_handler(event_metadata) + except Exception as ex: + _LOGGER.error("Exception when calling handler for Sdk Event %s", sdk_event) + _LOGGER.error(ex) + + async def deliver_async(self, sdk_event, event_metadata, event_handler): + try: + await event_handler(event_metadata) + except Exception as ex: + _LOGGER.error("Exception when calling handler for Sdk Event %s", sdk_event) + _LOGGER.error(ex) diff --git a/splitio/events/events_manager.py b/splitio/events/events_manager.py new file mode 100644 index 00000000..de8206f1 --- /dev/null +++ b/splitio/events/events_manager.py @@ -0,0 +1,251 @@ +"""Events Manager.""" +import threading +import logging +from collections import namedtuple +from splitio.optional.loaders import asyncio + +from splitio.events import EventsManagerInterface +from splitio.models.events import SdkEvent + +_LOGGER = logging.getLogger(__name__) + +ValidSdkEvent = namedtuple('ValidSdkEvent', ['sdk_event', 'valid']) +ActiveSubscriptions = namedtuple('ActiveSubscriptions', ['triggered', 'handler']) + +class EventsManagerBase(EventsManagerInterface): + """Events Manager class.""" + + def __init__(self, events_configurations, events_delivery): + """ + Construct Events Manager instance. + """ + self._active_subscriptions = {} + self._internal_events_status = {} + self._events_delivery = events_delivery + self._manager_config = events_configurations + + def register(self, sdk_event, event_handler): + # Implement in child class + pass + + def unregister(self, sdk_event): + # Implement in child class + pass + + def notify_internal_event(self, sdk_internal_event, event_metadata): + # Implement in child class + pass + + def destroy(self): + # Implement in child class + pass + + def _event_already_triggered(self, sdk_event): + if self._active_subscriptions.get(sdk_event) != None: + return self._active_subscriptions.get(sdk_event).triggered + + return False + + def _get_internal_event_status(self, sdk_internal_event): + if self._internal_events_status.get(sdk_internal_event) != None: + return self._internal_events_status[sdk_internal_event] + + return False + + def _update_internal_event_status(self, sdk_internal_event, status): + self._internal_events_status[sdk_internal_event] = status + + def _set_sdk_event_triggered(self, sdk_event): + if self._active_subscriptions.get(sdk_event) == None: + return + + if self._active_subscriptions.get(sdk_event).triggered == True: + return + + self._active_subscriptions[sdk_event] = self._active_subscriptions[sdk_event]._replace(triggered = True) + + def _get_event_handler(self, sdk_event): + if self._active_subscriptions.get(sdk_event) == None: + return None + + return self._active_subscriptions.get(sdk_event).handler + + def _get_sdk_event_if_applicable(self, sdk_internal_event): + final_sdk_event = ValidSdkEvent(None, False) + + events_to_fire = [] + require_any_sdk_event = self._check_require_any(sdk_internal_event) + if require_any_sdk_event.valid: + if (not self._event_already_triggered(require_any_sdk_event.sdk_event) and + self._execution_limit(require_any_sdk_event.sdk_event) == 1) or \ + self._execution_limit(require_any_sdk_event.sdk_event) == -1: + final_sdk_event = final_sdk_event._replace(sdk_event = require_any_sdk_event.sdk_event, + valid = self._check_prerequisites(require_any_sdk_event.sdk_event) and \ + self._check_suppressed_by(require_any_sdk_event.sdk_event)) + + if final_sdk_event.valid: + events_to_fire.append(final_sdk_event.sdk_event) + + [events_to_fire.append(sdk_event) for sdk_event in self._check_require_all()] + + return events_to_fire + + def _check_require_all(self): + events = [] + for require_name, require_value in self._manager_config.require_all.items(): + final_status = True + for val in require_value: + final_status &= self._get_internal_event_status(val) + + if final_status and \ + self._check_prerequisites(require_name) and \ + ((not self._event_already_triggered(require_name) and + self._execution_limit(require_name) == 1) or \ + self._execution_limit(require_name) == -1) and \ + len(require_value) > 0: + + events.append(require_name) + + return events + + def _check_prerequisites(self, sdk_event): + for name, value in self._manager_config.prerequisites.items(): + for val in value: + if name == sdk_event and not self._event_already_triggered(val): + return False + + return True + + def _check_suppressed_by(self, sdk_event): + for name, value in self._manager_config.suppressed_by.items(): + for val in value: + if name == sdk_event and self._event_already_triggered(val): + return False + + return True + + def _execution_limit(self, sdk_event): + limit = self._manager_config.execution_limits.get(sdk_event) + if limit == None: + return -1 + + return limit + + def _check_require_any(self, sdk_internal_event): + valid_sdk_event = ValidSdkEvent(None, False) + for name, val in self._manager_config.require_any.items(): + if sdk_internal_event in val: + valid_sdk_event = valid_sdk_event._replace(valid = True, sdk_event = name) + return valid_sdk_event + + return valid_sdk_event + +class EventsManager(EventsManagerBase): + """Events Manager class.""" + + def __init__(self, events_configurations, events_delivery): + """ + Construct Events Manager instance. + """ + EventsManagerBase.__init__(self, events_configurations, events_delivery) + self._lock = threading.RLock() + + def register(self, sdk_event, event_handler): + if self._active_subscriptions.get(sdk_event) != None and self._get_event_handler(sdk_event) != None: + return + + with self._lock: + # SDK ready already fired + if sdk_event == SdkEvent.SDK_READY and self._event_already_triggered(sdk_event): + self._active_subscriptions[sdk_event] = ActiveSubscriptions(True, event_handler) + _LOGGER.debug("EventsManager: Firing SDK_READY event for new subscription") + self._fire_sdk_event(sdk_event, None) + return + + self._active_subscriptions[sdk_event] = ActiveSubscriptions(False, event_handler) + + def unregister(self, sdk_event): + if self._active_subscriptions.get(sdk_event) == None: + return + + with self._lock: + del self._active_subscriptions[sdk_event] + + def notify_internal_event(self, sdk_internal_event, event_metadata): + with self._lock: + self._update_internal_event_status(sdk_internal_event, True) + for sorted_event in self._manager_config.evaluation_order: + if sorted_event in self._get_sdk_event_if_applicable(sdk_internal_event): + if self._get_event_handler(sorted_event) != None: + self._fire_sdk_event(sorted_event, event_metadata) + + # if client is not subscribed to SDK_READY + if sorted_event == SdkEvent.SDK_READY and self._get_event_handler(sorted_event) == None: + _LOGGER.debug("EventsManager: Registering SDK_READY event as fired") + self._active_subscriptions[SdkEvent.SDK_READY] = ActiveSubscriptions(True, None) + + def destroy(self): + with self._lock: + self._active_subscriptions = {} + self._internal_events_status = {} + + def _fire_sdk_event(self, sdk_event, event_metadata): + _LOGGER.debug("EventsManager: Firing Sdk event %s", sdk_event) + notify_event = threading.Thread(target=self._events_delivery.deliver, args=[sdk_event, event_metadata, self._get_event_handler(sdk_event)], + name='SplitSDKEventNotify', daemon=True) + notify_event.start() + self._set_sdk_event_triggered(sdk_event) + +class EventsManagerAsync(EventsManagerBase): + """Events Manager Async class.""" + + def __init__(self, events_configurations, events_delivery): + """ + Construct Events Manager instance. + """ + EventsManagerBase.__init__(self, events_configurations, events_delivery) + self._lock = asyncio.Lock() + + async def register(self, sdk_event, event_handler): + if self._active_subscriptions.get(sdk_event) != None and self._get_event_handler(sdk_event) != None: + return + + async with self._lock: + # SDK ready already fired + if sdk_event == SdkEvent.SDK_READY and self._event_already_triggered(sdk_event): + self._active_subscriptions[sdk_event] = ActiveSubscriptions(True, event_handler) + _LOGGER.debug("EventsManager: Firing SDK_READY event for new subscription") + self._fire_sdk_event(sdk_event, None) + return + + self._active_subscriptions[sdk_event] = ActiveSubscriptions(False, event_handler) + + async def unregister(self, sdk_event): + if self._active_subscriptions.get(sdk_event) == None: + return + + async with self._lock: + del self._active_subscriptions[sdk_event] + + async def notify_internal_event(self, sdk_internal_event, event_metadata): + async with self._lock: + self._update_internal_event_status(sdk_internal_event, True) + for sorted_event in self._manager_config.evaluation_order: + if sorted_event in self._get_sdk_event_if_applicable(sdk_internal_event): + if self._get_event_handler(sorted_event) != None: + self._fire_sdk_event(sorted_event, event_metadata) + + # if client is not subscribed to SDK_READY + if sorted_event == SdkEvent.SDK_READY and self._get_event_handler(sorted_event) == None: + _LOGGER.debug("EventsManager: Registering SDK_READY event as fired") + self._active_subscriptions[SdkEvent.SDK_READY] = ActiveSubscriptions(True, None) + + async def destroy(self): + async with self._lock: + self._active_subscriptions = {} + self._internal_events_status = {} + + def _fire_sdk_event(self, sdk_event, event_metadata): + _LOGGER.debug("EventsManager: Firing Sdk event %s", sdk_event) + asyncio.get_running_loop().create_task(self._events_delivery.deliver_async(sdk_event, event_metadata, self._get_event_handler(sdk_event))) + self._set_sdk_event_triggered(sdk_event) \ No newline at end of file diff --git a/splitio/events/events_manager_config.py b/splitio/events/events_manager_config.py new file mode 100644 index 00000000..b987d380 --- /dev/null +++ b/splitio/events/events_manager_config.py @@ -0,0 +1,111 @@ +"""Events Manager Configuration.""" +from splitio.models.events import SdkEvent, SdkInternalEvent + +class EventsManagerConfig(object): + """Events Manager Configurations class.""" + + def __init__(self): + """ + Construct Events Manager Configuration instance. + """ + self._require_all = self._get_require_all() + self._prerequisites = self._get_prerequisites() + self._require_any = self._get_require_any() + self._suppressed_by = self._get_suppressed_by() + self._execution_limits = self._get_execution_limits() + self._evaluation_order = self._get_sorted_events() + + @property + def require_all(self): + """Return require all dict""" + return self._require_all + + @property + def prerequisites(self): + """Return prerequisites dict""" + return self._prerequisites + + @property + def require_any(self): + """Return require_any dict""" + return self._require_any + + @property + def suppressed_by(self): + """Return suppressed_by dict""" + return self._suppressed_by + + @property + def execution_limits(self): + """Return execution_limits dict""" + return self._execution_limits + + @property + def evaluation_order(self): + """Return evaluation_order dict""" + return self._evaluation_order + + def _get_require_all(self): + """Return require all dict""" + return { + SdkEvent.SDK_READY: {SdkInternalEvent.SDK_READY} + } + + def _get_prerequisites(self): + """Return prerequisites dict""" + return { + SdkEvent.SDK_UPDATE: {SdkEvent.SDK_READY} + } + + def _get_require_any(self): + """Return require_any dict""" + return { + SdkEvent.SDK_UPDATE: {SdkInternalEvent.FLAG_KILLED_NOTIFICATION, SdkInternalEvent.FLAGS_UPDATED, + SdkInternalEvent.RB_SEGMENTS_UPDATED, SdkInternalEvent.SEGMENTS_UPDATED} + } + + def _get_suppressed_by(self): + """Return suppressed_by dict""" + return { + } + + def _get_execution_limits(self): + """Return execution_limits dict""" + return { + SdkEvent.SDK_READY: 1, + SdkEvent.SDK_UPDATE: -1 + } + + def _get_sorted_events(self): + """Return dorted events set""" + sorted_events = [] + for sdk_event in [SdkEvent.SDK_READY, SdkEvent.SDK_UPDATE]: + sorted_events = self._dfs_recursive(sdk_event, sorted_events) + + return sorted_events + + + def _dfs_recursive(self, sdk_event, added): + """Return sorted events set based on the dependency rules""" + if sdk_event in added: + return added + + for dependent_event in self._get_dependencies(sdk_event): + added = self._dfs_recursive(dependent_event, added) + + added.append(sdk_event) + return added + + def _get_dependencies(self, sdk_event): + """Return dependencies set from prerequisites and suppressed events for a given event""" + dependencies = set() + for prerequisites_event_name, prerequisites_event_value in self.prerequisites.items(): + if prerequisites_event_name == sdk_event: + for prereq_event in prerequisites_event_value: + dependencies.add(prereq_event) + + for suppressed_event_name, suppressed_event_value in self.suppressed_by.items(): + if sdk_event in suppressed_event_value: + dependencies.add(suppressed_event_name) + + return dependencies diff --git a/splitio/events/events_metadata.py b/splitio/events/events_metadata.py new file mode 100644 index 00000000..0707a8f5 --- /dev/null +++ b/splitio/events/events_metadata.py @@ -0,0 +1,35 @@ +"""Events Metadata.""" +from enum import Enum + +class SdkEventType(Enum): + """Public event types""" + + FLAG_UPDATE = 'FLAG_UPDATE' + SEGMENTS_UPDATE = 'SEGMENTS_UPDATE' + +class EventsMetadata(object): + """Events Metadata class.""" + + def __init__(self, type, names): + """ + Construct Events Metadata instance. + """ + self._type = type + self._names = self._sanitize(names) + + def get_type(self): + """Return type""" + return self._type + + def get_names(self): + """Return names""" + return self._names + + def _sanitize(self, names): + """Return sanitized names list with values str""" + santized_data = set() + for name in names: + if isinstance(name, str): + santized_data.add(name) + + return santized_data diff --git a/splitio/events/events_task.py b/splitio/events/events_task.py new file mode 100644 index 00000000..8158dc04 --- /dev/null +++ b/splitio/events/events_task.py @@ -0,0 +1,146 @@ +"""sdk internal events task.""" +import logging +import threading +import abc + +from splitio.optional.loaders import asyncio + +_LOGGER = logging.getLogger(__name__) + +class EventsTaskBase(object, metaclass=abc.ABCMeta): + """task template.""" + + @abc.abstractmethod + def is_running(self): + """Return whether the task is running.""" + + @abc.abstractmethod + def start(self): + """Start task.""" + + @abc.abstractmethod + def stop(self): + """Stop task.""" + +class EventsTask(EventsTaskBase): + """sdk internal events processing task.""" + + _centinel = object() + + def __init__(self, notify_internal_events, internal_events_queue): + """ + Class constructor. + + :param synchronize_segment: handler to perform segment synchronization on incoming event + :type synchronize_segment: function + + :param segment_queue: queue with segment updates notifications + :type segment_queue: queue + """ + self._internal_events_queue = internal_events_queue + self._handler = notify_internal_events + self._running = False + self._worker = None + + def is_running(self): + """Return whether the working is running.""" + return self._running + + def _run(self): + """Run worker handler.""" + while self.is_running(): + event = self._internal_events_queue.get() + if not self.is_running(): + break + + if event == self._centinel: + continue + + _LOGGER.debug('Processing sdk internal event: %s', event.internal_event) + try: + self._handler(event.internal_event, event.metadata) + except Exception: + _LOGGER.error('Exception raised in events manager') + _LOGGER.debug('Exception information: ', exc_info=True) + + def start(self): + """Start worker.""" + if self.is_running(): + _LOGGER.debug('SDK Event Worker is already running') + return + + self._running = True + _LOGGER.debug('Starting SDK Event Task worker') + self._worker = threading.Thread(target=self._run, name='EventsTaskWorker', daemon=True) + self._worker.start() + + def stop(self, stop_flag=None): + """Stop worker.""" + _LOGGER.debug('Stopping SDK Event Task worker') + if not self.is_running(): + _LOGGER.debug('SDK Event Worker is not running. Ignoring.') + return + + self._running = False + self._internal_events_queue.put(self._centinel) + +class EventsTaskAsync(EventsTaskBase): + """sdk internal events processing task.""" + + _centinel = object() + + def __init__(self, notify_internal_events, internal_events_queue): + """ + Class constructor. + + :param synchronize_segment: handler to perform segment synchronization on incoming event + :type synchronize_segment: function + + :param segment_queue: queue with segment updates notifications + :type segment_queue: queue + """ + self._internal_events_queue = internal_events_queue + self._handler = notify_internal_events + self._running = False + self._worker = None + + def is_running(self): + """Return whether the working is running.""" + return self._running + + async def _run(self): + """Run worker handler.""" + while self.is_running(): + event = await self._internal_events_queue.get() + if not self.is_running(): + break + + if event == self._centinel: + continue + + _LOGGER.debug('Processing sdk internal event: %s', event.internal_event) + try: + await self._handler(event.internal_event, event.metadata) + except Exception: + _LOGGER.error('Exception raised in events manager') + _LOGGER.debug('Exception information: ', exc_info=True) + + def start(self): + """Start worker.""" + if self.is_running(): + _LOGGER.debug('SDK Event Worker is already running') + return + + self._running = True + _LOGGER.debug('Starting SDK Event Task worker') + asyncio.get_running_loop().create_task(self._run()) + + async def stop(self, stop_flag=None): + """Stop worker.""" + _LOGGER.debug('Stopping SDK Event Task worker') + if not self.is_running(): + _LOGGER.debug('SDK Event Worker is not running. Ignoring.') + return + + self._running = False + await self._internal_events_queue.put(self._centinel) \ No newline at end of file diff --git a/splitio/models/__init__.py b/splitio/models/__init__.py index e69de29b..ea86ed44 100644 --- a/splitio/models/__init__.py +++ b/splitio/models/__init__.py @@ -0,0 +1,6 @@ +class MatcherNotFoundException(Exception): + """Exception to raise when a matcher is not found.""" + + def __init__(self, custom_message): + """Constructor.""" + Exception.__init__(self, custom_message) \ No newline at end of file diff --git a/splitio/models/events.py b/splitio/models/events.py index b924417b..2863d235 100644 --- a/splitio/models/events.py +++ b/splitio/models/events.py @@ -4,7 +4,7 @@ The dto is implemented as a namedtuple for performance matters. """ from collections import namedtuple - +from enum import Enum Event = namedtuple('Event', [ 'key', @@ -19,3 +19,21 @@ 'event', 'size', ]) + +class SdkEvent(Enum): + """Public SDK events""" + + SDK_READY = 'SDK_READY' + SDK_UPDATE = 'SDK_UPDATE' + +class SdkInternalEvent(Enum): + """Internal SDK events""" + + SDK_READY = 'SDK_READY' + FLAGS_UPDATED = 'FLAGS_UPDATED' + FLAG_KILLED_NOTIFICATION = 'FLAG_KILLED_NOTIFICATION' + SEGMENTS_UPDATED = 'SEGMENTS_UPDATED' + RB_SEGMENTS_UPDATED = 'RB_SEGMENTS_UPDATED' + LARGE_SEGMENTS_UPDATED = 'LARGE_SEGMENTS_UPDATED' + + diff --git a/splitio/models/fallback_config.py b/splitio/models/fallback_config.py new file mode 100644 index 00000000..ca021bf7 --- /dev/null +++ b/splitio/models/fallback_config.py @@ -0,0 +1,100 @@ +"""Segment module.""" +from splitio.models.fallback_treatment import FallbackTreatment +from splitio.client.client import CONTROL + +class FallbackTreatmentsConfiguration(object): + """FallbackTreatmentsConfiguration object class.""" + + def __init__(self, global_fallback_treatment=None, by_flag_fallback_treatment=None): + """ + Class constructor. + + :param global_fallback_treatment: global FallbackTreatment. + :type global_fallback_treatment: FallbackTreatment + + :param by_flag_fallback_treatment: Dict of flags and their fallback treatment + :type by_flag_fallback_treatment: {str: FallbackTreatment} + """ + self._global_fallback_treatment = self._build_global_fallback(global_fallback_treatment) + self._by_flag_fallback_treatment = self._build_by_flag_fallback(by_flag_fallback_treatment) + + @property + def global_fallback_treatment(self): + """Return global fallback treatment.""" + return self._global_fallback_treatment + + @global_fallback_treatment.setter + def global_fallback_treatment(self, new_value): + """Set global fallback treatment.""" + self._global_fallback_treatment = new_value + + @property + def by_flag_fallback_treatment(self): + """Return by flag fallback treatment.""" + return self._by_flag_fallback_treatment + + @by_flag_fallback_treatment.setter + def by_flag_fallback_treatment(self, new_value): + """Set global fallback treatment.""" + self.by_flag_fallback_treatment = new_value + + def _build_global_fallback(self, global_fallback_treatment): + if isinstance(global_fallback_treatment, str): + return FallbackTreatment(global_fallback_treatment) + + return global_fallback_treatment + + def _build_by_flag_fallback(self, by_flag_fallback_treatment): + if not isinstance(by_flag_fallback_treatment, dict): + return by_flag_fallback_treatment + + parsed_by_flag_fallback = {} + for key, value in by_flag_fallback_treatment.items(): + if isinstance(value, str): + parsed_by_flag_fallback[key] = FallbackTreatment(value) + else: + parsed_by_flag_fallback[key] = value + + return parsed_by_flag_fallback + +class FallbackTreatmentCalculator(object): + """FallbackTreatmentCalculator object class.""" + + def __init__(self, fallback_treatment_configuration): + """ + Class constructor. + + :param fallback_treatment_configuration: fallback treatment configuration + :type fallback_treatment_configuration: FallbackTreatmentsConfiguration + """ + self._label_prefix = "fallback - " + self._fallback_treatments_configuration = fallback_treatment_configuration + + @property + def fallback_treatments_configuration(self): + """Return fallback treatment configuration.""" + return self._fallback_treatments_configuration + + def resolve(self, flag_name, label): + if self._fallback_treatments_configuration != None: + if self._fallback_treatments_configuration.by_flag_fallback_treatment != None \ + and self._fallback_treatments_configuration.by_flag_fallback_treatment.get(flag_name) != None: + return self._copy_with_label(self._fallback_treatments_configuration.by_flag_fallback_treatment.get(flag_name), \ + self._resolve_label(label)) + + if self._fallback_treatments_configuration.global_fallback_treatment != None: + return self._copy_with_label(self._fallback_treatments_configuration.global_fallback_treatment, \ + self._resolve_label(label)) + + return FallbackTreatment(CONTROL, None, label) + + def _resolve_label(self, label): + if label == None: + return None + + return self._label_prefix + label + + def _copy_with_label(self, fallback_treatment, label): + return FallbackTreatment(fallback_treatment.treatment, fallback_treatment.config, label) + + \ No newline at end of file diff --git a/splitio/models/fallback_treatment.py b/splitio/models/fallback_treatment.py new file mode 100644 index 00000000..794cbb63 --- /dev/null +++ b/splitio/models/fallback_treatment.py @@ -0,0 +1,34 @@ +"""Segment module.""" +import json + +class FallbackTreatment(object): + """FallbackTreatment object class.""" + + def __init__(self, treatment, config=None, label=None): + """ + Class constructor. + + :param treatment: treatment. + :type treatment: str + + :param config: config. + :type config: json + """ + self._treatment = treatment + self._config = config + self._label = label + + @property + def treatment(self): + """Return treatment.""" + return self._treatment + + @property + def config(self): + """Return config.""" + return self._config + + @property + def label(self): + """Return label prefix.""" + return self._label \ No newline at end of file diff --git a/splitio/models/grammar/condition.py b/splitio/models/grammar/condition.py index d38e6991..79fdb928 100644 --- a/splitio/models/grammar/condition.py +++ b/splitio/models/grammar/condition.py @@ -2,6 +2,7 @@ from enum import Enum +from splitio.models import MatcherNotFoundException from splitio.models.grammar import matchers from splitio.models.grammar import partitions @@ -11,7 +12,7 @@ class ConditionType(Enum): - """Split possible condition types.""" + """Feature Flag possible condition types.""" WHITELIST = 'WHITELIST' ROLLOUT = 'ROLLOUT' @@ -112,18 +113,21 @@ def from_raw(raw_condition): """ Parse a condition from a JSON portion of splitChanges. - :param raw_condition: JSON object extracted from a split's conditions array. + :param raw_condition: JSON object extracted from a feature flag's conditions array. :type raw_condition: dict :return: A condition object. :rtype: Condition """ - parsed_partitions = [ - partitions.from_raw(raw_partition) - for raw_partition in raw_condition['partitions'] - ] + parsed_partitions = [] + if raw_condition.get("partitions") is not None: + parsed_partitions = [ + partitions.from_raw(raw_partition) + for raw_partition in raw_condition['partitions'] + ] matcher_objects = [matchers.from_raw(x) for x in raw_condition['matcherGroup']['matchers']] + combiner = _MATCHER_COMBINERS[raw_condition['matcherGroup']['combiner']] label = raw_condition.get('label') diff --git a/splitio/models/grammar/matchers/__init__.py b/splitio/models/grammar/matchers/__init__.py index bab9abad..def75626 100644 --- a/splitio/models/grammar/matchers/__init__.py +++ b/splitio/models/grammar/matchers/__init__.py @@ -1,4 +1,5 @@ """Matchers entrypoint module.""" +from splitio.models import MatcherNotFoundException from splitio.models.grammar.matchers.keys import AllKeysMatcher, UserDefinedSegmentMatcher from splitio.models.grammar.matchers.numeric import BetweenMatcher, EqualToMatcher, \ GreaterThanOrEqualMatcher, LessThanOrEqualMatcher @@ -7,6 +8,9 @@ from splitio.models.grammar.matchers.string import ContainsStringMatcher, \ EndsWithMatcher, RegexMatcher, StartsWithMatcher, WhitelistMatcher from splitio.models.grammar.matchers.misc import BooleanMatcher, DependencyMatcher +from splitio.models.grammar.matchers.semver import EqualToSemverMatcher, GreaterThanOrEqualToSemverMatcher, LessThanOrEqualToSemverMatcher, \ + BetweenSemverMatcher, InListSemverMatcher +from splitio.models.grammar.matchers.rule_based_segment import RuleBasedSegmentMatcher MATCHER_TYPE_ALL_KEYS = 'ALL_KEYS' @@ -26,6 +30,12 @@ MATCHER_TYPE_IN_SPLIT_TREATMENT = 'IN_SPLIT_TREATMENT' MATCHER_TYPE_EQUAL_TO_BOOLEAN = 'EQUAL_TO_BOOLEAN' MATCHER_TYPE_MATCHES_STRING = 'MATCHES_STRING' +MATCHER_TYPE_EQUAL_TO_SEMVER = 'EQUAL_TO_SEMVER' +MATCHER_GREATER_THAN_OR_EQUAL_TO_SEMVER = 'GREATER_THAN_OR_EQUAL_TO_SEMVER' +MATCHER_LESS_THAN_OR_EQUAL_TO_SEMVER = 'LESS_THAN_OR_EQUAL_TO_SEMVER' +MATCHER_BETWEEN_SEMVER = 'BETWEEN_SEMVER' +MATCHER_INLIST_SEMVER = 'IN_LIST_SEMVER' +MATCHER_IN_RULE_BASED_SEGMENT = 'IN_RULE_BASED_SEGMENT' _MATCHER_BUILDERS = { @@ -45,10 +55,15 @@ MATCHER_TYPE_CONTAINS_STRING: ContainsStringMatcher, MATCHER_TYPE_IN_SPLIT_TREATMENT: DependencyMatcher, MATCHER_TYPE_EQUAL_TO_BOOLEAN: BooleanMatcher, - MATCHER_TYPE_MATCHES_STRING: RegexMatcher + MATCHER_TYPE_MATCHES_STRING: RegexMatcher, + MATCHER_TYPE_EQUAL_TO_SEMVER: EqualToSemverMatcher, + MATCHER_GREATER_THAN_OR_EQUAL_TO_SEMVER: GreaterThanOrEqualToSemverMatcher, + MATCHER_LESS_THAN_OR_EQUAL_TO_SEMVER: LessThanOrEqualToSemverMatcher, + MATCHER_BETWEEN_SEMVER: BetweenSemverMatcher, + MATCHER_INLIST_SEMVER: InListSemverMatcher, + MATCHER_IN_RULE_BASED_SEGMENT: RuleBasedSegmentMatcher } - def from_raw(raw_matcher): """ Parse a condition from a JSON portion of splitChanges. @@ -63,5 +78,5 @@ def from_raw(raw_matcher): try: builder = _MATCHER_BUILDERS[matcher_type] except KeyError: - raise ValueError('Invalid matcher type %s' % matcher_type) + raise MatcherNotFoundException('Invalid matcher type %s' % matcher_type) return builder(raw_matcher) diff --git a/splitio/models/grammar/matchers/base.py b/splitio/models/grammar/matchers/base.py index 0040d700..57d0feb5 100644 --- a/splitio/models/grammar/matchers/base.py +++ b/splitio/models/grammar/matchers/base.py @@ -41,6 +41,7 @@ def _get_matcher_input(self, key, attributes=None): if self._attribute_name is not None: if attributes is not None and attributes.get(self._attribute_name) is not None: return attributes[self._attribute_name] + return None if isinstance(key, Key): diff --git a/splitio/models/grammar/matchers/keys.py b/splitio/models/grammar/matchers/keys.py index 7f10fec8..0d719310 100644 --- a/splitio/models/grammar/matchers/keys.py +++ b/splitio/models/grammar/matchers/keys.py @@ -65,14 +65,11 @@ def _match(self, key, attributes=None, context=None): :returns: Wheter the match is successful. :rtype: bool """ - segment_storage = context.get('segment_storage') - if not segment_storage: - raise Exception('Segment storage not present in matcher context.') - matching_data = self._get_matcher_input(key, attributes) if matching_data is None: return False - return segment_storage.segment_contains(self._segment_name, matching_data) + + return context['ec'].segment_memberships[self._segment_name] def _add_matcher_specific_properties_to_json(self): """Return UserDefinedSegment specific properties.""" diff --git a/splitio/models/grammar/matchers/misc.py b/splitio/models/grammar/matchers/misc.py index a484db07..1f52c1fa 100644 --- a/splitio/models/grammar/matchers/misc.py +++ b/splitio/models/grammar/matchers/misc.py @@ -35,8 +35,7 @@ def _match(self, key, attributes=None, context=None): assert evaluator is not None bucketing_key = context.get('bucketing_key') - - result = evaluator.evaluate_feature(self._split_name, key, bucketing_key, attributes) + result = evaluator.eval_with_context(key, bucketing_key, self._split_name, attributes, context['ec']) return result['treatment'] in self._treatments def _add_matcher_specific_properties_to_json(self): @@ -78,6 +77,7 @@ def _match(self, key, attributes=None, context=None): matching_data = self._get_matcher_input(key, attributes) if matching_data is None: return False + if isinstance(matching_data, bool): decoded = matching_data elif isinstance(matching_data, str): @@ -85,8 +85,10 @@ def _match(self, key, attributes=None, context=None): decoded = json.loads(matching_data.lower()) if not isinstance(decoded, bool): return False + except ValueError: return False + else: return False diff --git a/splitio/models/grammar/matchers/numeric.py b/splitio/models/grammar/matchers/numeric.py index a722da0d..c39fabd7 100644 --- a/splitio/models/grammar/matchers/numeric.py +++ b/splitio/models/grammar/matchers/numeric.py @@ -106,6 +106,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return self._lower <= self.input_parsers[self._data_type](matching_data) <= self._upper def __str__(self): @@ -154,6 +155,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return self.input_parsers[self._data_type](matching_data) == self._value def _add_matcher_specific_properties_to_json(self): @@ -197,6 +199,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return self.input_parsers[self._data_type](matching_data) >= self._value def _add_matcher_specific_properties_to_json(self): @@ -240,6 +243,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return self.input_parsers[self._data_type](matching_data) <= self._value def _add_matcher_specific_properties_to_json(self): diff --git a/splitio/models/grammar/matchers/prerequisites.py b/splitio/models/grammar/matchers/prerequisites.py new file mode 100644 index 00000000..799df5c4 --- /dev/null +++ b/splitio/models/grammar/matchers/prerequisites.py @@ -0,0 +1,38 @@ +"""Prerequisites matcher classes.""" + +class PrerequisitesMatcher(object): + + def __init__(self, prerequisites): + """ + Build a PrerequisitesMatcher. + + :param prerequisites: prerequisites + :type raw_matcher: List of Prerequisites + """ + self._prerequisites = prerequisites + + def match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._prerequisites == None: + return True + + evaluator = context.get('evaluator') + bucketing_key = context.get('bucketing_key') + for prerequisite in self._prerequisites: + result = evaluator.eval_with_context(key, bucketing_key, prerequisite.feature_flag_name, attributes, context['ec']) + if result['treatment'] not in prerequisite.treatments: + return False + + return True \ No newline at end of file diff --git a/splitio/models/grammar/matchers/rule_based_segment.py b/splitio/models/grammar/matchers/rule_based_segment.py new file mode 100644 index 00000000..6e4c8023 --- /dev/null +++ b/splitio/models/grammar/matchers/rule_based_segment.py @@ -0,0 +1,72 @@ +"""Rule based segment matcher classes.""" +from splitio.models.grammar.matchers.base import Matcher +from splitio.models.rule_based_segments import SegmentType + +class RuleBasedSegmentMatcher(Matcher): + + def _build(self, raw_matcher): + """ + Build an RuleBasedSegmentMatcher. + + :param raw_matcher: raw matcher as fetched from splitChanges response. + :type raw_matcher: dict + """ + self._rbs_segment_name = raw_matcher['userDefinedSegmentMatcherData']['segmentName'] + + def _match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._rbs_segment_name == None: + return False + + rb_segment = context['ec'].rbs_segments.get(self._rbs_segment_name) + + if key in rb_segment.excluded.get_excluded_keys(): + return False + + if self._match_dep_rb_segments(rb_segment.excluded.get_excluded_segments(), key, attributes, context): + return False + + return self._match_conditions(rb_segment.conditions, key, attributes, context) + + def _add_matcher_specific_properties_to_json(self): + """Return UserDefinedSegment specific properties.""" + return { + 'userDefinedSegmentMatcherData': { + 'segmentName': self._rbs_segment_name + } + } + + def _match_conditions(self, rbs_segment_conditions, key, attributes, context): + for parsed_condition in rbs_segment_conditions: + if parsed_condition.matches(key, attributes, context): + return True + + return False + + def _match_dep_rb_segments(self, excluded_rb_segments, key, attributes, context): + for excluded_rb_segment in excluded_rb_segments: + if excluded_rb_segment.type == SegmentType.STANDARD: + if context['ec'].segment_memberships[excluded_rb_segment.name]: + return True + else: + excluded_segment = context['ec'].rbs_segments.get(excluded_rb_segment.name) + if key in excluded_segment.excluded.get_excluded_keys(): + return False + + if self._match_dep_rb_segments(excluded_segment.excluded.get_excluded_segments(), key, attributes, context) \ + or self._match_conditions(excluded_segment.conditions, key, attributes, context): + return True + + return False diff --git a/splitio/models/grammar/matchers/semver.py b/splitio/models/grammar/matchers/semver.py new file mode 100644 index 00000000..46ccf01d --- /dev/null +++ b/splitio/models/grammar/matchers/semver.py @@ -0,0 +1,260 @@ +"""Semver matcher classes.""" +import logging + +from splitio.models.grammar.matchers.base import Matcher +from splitio.models.grammar.matchers.string import Sanitizer +from splitio.models.grammar.matchers.utils.utils import build_semver_or_none + + +_LOGGER = logging.getLogger(__name__) + + +class EqualToSemverMatcher(Matcher): + """A matcher for Semver equal to.""" + + def _build(self, raw_matcher): + """ + Build an EqualToSemverMatcher. + + :param raw_matcher: raw matcher as fetched from splitChanges response. + :type raw_matcher: dict + """ + self._data = raw_matcher.get('stringMatcherData') + self._semver = build_semver_or_none(raw_matcher.get('stringMatcherData')) + + def _match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._semver is None: + _LOGGER.error("stringMatcherData is required for EQUAL_TO_SEMVER matcher type") + return False + + matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) + if matching_data is None: + return False + + matching_semver = build_semver_or_none(matching_data) + if matching_semver is None: + return False + + return self._semver.version == matching_semver.version + + def __str__(self): + """Return string Representation.""" + return f'equal semver {self._data}' + + def _add_matcher_specific_properties_to_json(self): + """Add matcher specific properties to base dict before returning it.""" + return {'matcherType': 'EQUAL_TO_SEMVER', 'stringMatcherData': self._data} + +class GreaterThanOrEqualToSemverMatcher(Matcher): + """A matcher for Semver greater than or equal to.""" + + def _build(self, raw_matcher): + """ + Build a GreaterThanOrEqualToSemverMatcher. + + :param raw_matcher: raw matcher as fetched from splitChanges response. + :type raw_matcher: dict + """ + self._data = raw_matcher.get('stringMatcherData') + self._semver = build_semver_or_none(raw_matcher.get('stringMatcherData')) + + def _match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._semver is None: + _LOGGER.error("stringMatcherData is required for GREATER_THAN_OR_EQUAL_TO_SEMVER matcher type") + return False + + matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) + if matching_data is None: + return False + + matching_semver = build_semver_or_none(matching_data) + if matching_semver is None: + return False + + return matching_semver.compare(self._semver) in [0, 1] + + def __str__(self): + """Return string Representation.""" + return f'greater than or equal to semver {self._data}' + + def _add_matcher_specific_properties_to_json(self): + """Add matcher specific properties to base dict before returning it.""" + return {'matcherType': 'GREATER_THAN_OR_EQUAL_TO_SEMVER', 'stringMatcherData': self._data} + + +class LessThanOrEqualToSemverMatcher(Matcher): + """A matcher for Semver less than or equal to.""" + + def _build(self, raw_matcher): + """ + Build a LessThanOrEqualToSemverMatcher. + + :param raw_matcher: raw matcher as fetched from splitChanges response. + :type raw_matcher: dict + """ + self._data = raw_matcher.get('stringMatcherData') + self._semver = build_semver_or_none(raw_matcher.get('stringMatcherData')) + + def _match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._semver is None: + _LOGGER.error("stringMatcherData is required for LESS_THAN_OR_EQUAL_TO_SEMVER matcher type") + return False + + matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) + if matching_data is None: + return False + + matching_semver = build_semver_or_none(matching_data) + if matching_semver is None: + return False + + return matching_semver.compare(self._semver) in [0, -1] + + def __str__(self): + """Return string Representation.""" + return f'less than or equal to semver {self._data}' + + def _add_matcher_specific_properties_to_json(self): + """Add matcher specific properties to base dict before returning it.""" + return {'matcherType': 'LESS_THAN_OR_EQUAL_TO_SEMVER', 'stringMatcherData': self._data} + + +class BetweenSemverMatcher(Matcher): + """A matcher for Semver between.""" + + def _build(self, raw_matcher): + """ + Build a BetweenSemverMatcher. + + :param raw_matcher: raw matcher as fetched from splitChanges response. + :type raw_matcher: dict + """ + self._data = raw_matcher.get('betweenStringMatcherData') + self._semver_start = build_semver_or_none(self._data['start']) + self._semver_end = build_semver_or_none(self._data['end']) + + def _match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._semver_start is None or self._semver_end is None: + _LOGGER.error("betweenStringMatcherData is required for BETWEEN_SEMVER matcher type") + return False + + matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) + if matching_data is None: + return False + + matching_semver = build_semver_or_none(matching_data) + if matching_semver is None: + return False + + return (self._semver_start.compare(matching_semver) in [0, -1]) and (self._semver_end.compare(matching_semver) in [0, 1]) + + def __str__(self): + """Return string Representation.""" + return 'between semver {start} and {end}'.format(start=self._data.get('start'), end=self._data.get('end')) + + def _add_matcher_specific_properties_to_json(self): + """Add matcher specific properties to base dict before returning it.""" + return {'matcherType': 'BETWEEN_SEMVER', 'betweenStringMatcherData': self._data} + + +class InListSemverMatcher(Matcher): + """A matcher for Semver in list.""" + + def _build(self, raw_matcher): + """ + Build a InListSemverMatcher. + + :param raw_matcher: raw matcher as fetched from splitChanges response. + :type raw_matcher: dict + """ + self._data = raw_matcher['whitelistMatcherData']['whitelist'] + semver_list = [build_semver_or_none(item) for item in self._data if item] + self._semver_list = frozenset([item.version for item in semver_list if item]) + + def _match(self, key, attributes=None, context=None): + """ + Evaluate user input against a matcher and return whether the match is successful. + + :param key: User key. + :type key: str. + :param attributes: Custom user attributes. + :type attributes: dict. + :param context: Evaluation context + :type context: dict + + :returns: Wheter the match is successful. + :rtype: bool + """ + if self._semver_list is None: + _LOGGER.error("whitelistMatcherData is required for IN_LIST_SEMVER matcher type") + return False + + matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) + if matching_data is None: + return False + + matching_semver = build_semver_or_none(matching_data) + if matching_semver is None: + return False + + return matching_semver.version in self._semver_list + + def __str__(self): + """Return string Representation.""" + return 'in list semver {data}'.format(data=self._data) + + def _add_matcher_specific_properties_to_json(self): + """Add matcher specific properties to base dict before returning it.""" + return {'matcherType': 'IN_LIST_SEMVER', 'whitelistMatcherData': {'whitelist': self._data}} diff --git a/splitio/models/grammar/matchers/sets.py b/splitio/models/grammar/matchers/sets.py index 49890a98..f46970b4 100644 --- a/splitio/models/grammar/matchers/sets.py +++ b/splitio/models/grammar/matchers/sets.py @@ -31,9 +31,11 @@ def _match(self, key, attributes=None, context=None): matching_data = self._get_matcher_input(key, attributes) if matching_data is None: return False + try: setkey = set(matching_data) return self._whitelist.issubset(setkey) + except TypeError: return False @@ -81,8 +83,10 @@ def _match(self, key, attributes=None, context=None): matching_data = self._get_matcher_input(key, attributes) if matching_data is None: return False + try: return len(self._whitelist.intersection(set(matching_data))) != 0 + except TypeError: return False @@ -130,8 +134,10 @@ def _match(self, key, attributes=None, context=None): matching_data = self._get_matcher_input(key, attributes) if matching_data is None: return False + try: return self._whitelist == set(matching_data) + except TypeError: return False @@ -179,9 +185,11 @@ def _match(self, key, attributes=None, context=None): matching_data = self._get_matcher_input(key, attributes) if matching_data is None: return False + try: setkey = set(matching_data) return len(setkey) > 0 and setkey.issubset(set(self._whitelist)) + except TypeError: return False diff --git a/splitio/models/grammar/matchers/string.py b/splitio/models/grammar/matchers/string.py index 788972c6..1a820b21 100644 --- a/splitio/models/grammar/matchers/string.py +++ b/splitio/models/grammar/matchers/string.py @@ -35,6 +35,7 @@ def ensure_string(cls, data): ) try: return json.dumps(data) + except TypeError: return None @@ -68,6 +69,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return matching_data in self._whitelist def _add_matcher_specific_properties_to_json(self): @@ -114,6 +116,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return (isinstance(key, str) and any(matching_data.startswith(s) for s in self._whitelist)) @@ -161,6 +164,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return (isinstance(key, str) and any(matching_data.endswith(s) for s in self._whitelist)) @@ -208,6 +212,7 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) if matching_data is None: return False + return (isinstance(matching_data, str) and any(s in matching_data for s in self._whitelist)) @@ -256,9 +261,11 @@ def _match(self, key, attributes=None, context=None): matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes)) if matching_data is None: return False + try: matches = re.search(self._regex, matching_data) return matches is not None + except TypeError: return False diff --git a/splitio/models/grammar/matchers/utils/__init__.py b/splitio/models/grammar/matchers/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/splitio/models/grammar/matchers/utils/utils.py b/splitio/models/grammar/matchers/utils/utils.py new file mode 100644 index 00000000..d0b2e727 --- /dev/null +++ b/splitio/models/grammar/matchers/utils/utils.py @@ -0,0 +1,168 @@ +"""Utils module.""" + +import logging + +_LOGGER = logging.getLogger(__name__) + +M_DELIMITER = "+" +P_DELIMITER = "-" +V_DELIMITER = "." + + +def compare(var1, var2): + """ + Compare 2 variables and return int as follows: + 0: if var1 == var2 + 1: if var1 > var2 + -1: if var1 < var2 + + :param var1: any object accept ==, < or > operators + :type var1: str/int + :param var2: any object accept ==, < or > operators + :type var2: str/int + + :returns: integer based on comparison + :rtype: int + """ + if var1 == var2: + return 0 + if var1 > var2: + return 1 + return -1 + + +def build_semver_or_none(version): + try: + return Semver(version) + except (RuntimeError, ValueError): + _LOGGER.error("Invalid semver version: %s", version) + return None + + +class Semver(object): + """Semver class.""" + + def __init__(self, version): + """ + Class Initializer + + :param version: raw version as read from splitChanges response. + :type version: str + """ + self._major = 0 + self._minor = 0 + self._patch = 0 + self._pre_release = [] + self._is_stable = False + self._version = "" + self._metadata = "" + self._parse(version) + + def _parse(self, version): + """ + Parse the string in self.version to update the other internal variables + """ + without_metadata = self._extract_metadata(version) + index = without_metadata.find(P_DELIMITER) + if index == -1: + self._is_stable = True + else: + pre_release_data = without_metadata[index+1:] + if pre_release_data == "": + raise RuntimeError("Pre-release is empty despite delimiter exists: " + version) + + without_metadata = without_metadata[:index] + for pre_digit in pre_release_data.split(V_DELIMITER): + if pre_digit.isnumeric(): + pre_digit = str(int(pre_digit)) + self._pre_release.append(pre_digit) + + self._set_components(without_metadata) + + def _extract_metadata(self, version): + """ + Check if there is any metadata characters in self.version. + + :returns: The semver string without the metadata + :rtype: str + """ + index = version.find(M_DELIMITER) + if index == -1: + return version + + self._metadata = version[index+1:] + if self._metadata == "": + raise RuntimeError("Metadata is empty despite delimiter exists: " + version) + + return version[:index] + + def _set_components(self, version): + """ + Set the major, minor and patch internal variables based on string passed. + + :param version: raw version containing major.minor.patch numbers. + :type version: str + """ + + parts = version.split(V_DELIMITER) + if len(parts) != 3: + raise RuntimeError("Unable to convert to Semver, incorrect format: " + version) + try: + self._major, self._minor, self._patch = int(parts[0]), int(parts[1]), int(parts[2]) + self._version = f"{self._major}{V_DELIMITER}{self._minor}{V_DELIMITER}{self._patch}" + self._version += f"{P_DELIMITER + V_DELIMITER.join(self._pre_release) if len(self._pre_release) > 0 else ''}" + self._version += f"{M_DELIMITER + self._metadata if self._metadata else ''}" + except Exception: + raise RuntimeError("Unable to convert to Semver, incorrect format: " + version) + + @property + def version(self): + return self._version + + def compare(self, to_compare): + """ + Compare the current Semver object to a given Semver object, return: + 0: if self == passed + 1: if self > passed + -1: if self < passed + + :param to_compare: a Semver object + :type to_compare: splitio.models.grammar.matchers.semver.Semver + + :returns: integer based on comparison + :rtype: int + """ + if self.version == to_compare.version: + return 0 + + # Compare major, minor, and patch versions numerically + result = compare(self._major, to_compare._major) + if result != 0: + return result + + result = compare(self._minor, to_compare._minor) + if result != 0: + return result + + result = compare(self._patch, to_compare._patch) + if result != 0: + return result + + if not self._is_stable and to_compare._is_stable: + return -1 + elif self._is_stable and not to_compare._is_stable: + return 1 + + # Compare pre-release versions lexically + min_length = min(len(self._pre_release), len(to_compare._pre_release)) + for i in range(min_length): + if self._pre_release[i] == to_compare._pre_release[i]: + continue + + if self._pre_release[i].isnumeric() and to_compare._pre_release[i].isnumeric(): + return compare(int(self._pre_release[i]), int(to_compare._pre_release[i])) + + return compare(self._pre_release[i], to_compare._pre_release[i]) + + # Compare lengths of pre-release versions + return compare(len(self._pre_release), len(to_compare._pre_release)) diff --git a/splitio/models/impressions.py b/splitio/models/impressions.py index b08d31fb..0c6d50f7 100644 --- a/splitio/models/impressions.py +++ b/splitio/models/impressions.py @@ -12,7 +12,16 @@ 'change_number', 'bucketing_key', 'time', - 'previous_time' + 'previous_time', + 'properties' + ] +) + +ImpressionDecorated = namedtuple( + 'ImpressionDecorated', + [ + 'Impression', + 'disabled' ] ) @@ -52,3 +61,8 @@ class Label(object): # pylint: disable=too-few-public-methods # Treatment: control # Label: not ready NOT_READY = 'not ready' + + # Condition: Prerequisites not met + # Treatment: Default treatment + # Label: prerequisites not met + PREREQUISITES_NOT_MET = "prerequisites not met" diff --git a/splitio/models/notification.py b/splitio/models/notification.py index ebe57175..60b629e1 100644 --- a/splitio/models/notification.py +++ b/splitio/models/notification.py @@ -170,6 +170,29 @@ def notification_type(self): def split_name(self): return self._split_name +class SdkInternalEventNotification(object): # pylint: disable=too-many-instance-attributes + """SdkInternalEventNotification model object.""" + + def __init__(self, internal_event, metadata): + """ + Class constructor. + + :param internal_event: internal event object + :type channel: SdkInternalEvent + :param metadata: metadata associated with event + :type change_number: EventsMetadata + + """ + self._internal_event = internal_event + self._metadata = metadata + + @property + def internal_event(self): + return self._internal_event + + @property + def metadata(self): + return self._metadata _NOTIFICATION_MAPPERS = { Type.SPLIT_UPDATE: lambda c, d: SplitChangeNotification(c, Type.SPLIT_UPDATE, d['changeNumber']), @@ -195,6 +218,7 @@ def wrap_notification(raw_data, channel): notification_type = Type(raw_data['type']) mapper = _NOTIFICATION_MAPPERS[notification_type] return mapper(channel, raw_data) + except ValueError: raise ValueError("Wrong notification type received.") except KeyError: diff --git a/splitio/models/rule_based_segments.py b/splitio/models/rule_based_segments.py new file mode 100644 index 00000000..f7bf3f4d --- /dev/null +++ b/splitio/models/rule_based_segments.py @@ -0,0 +1,195 @@ +"""RuleBasedSegment module.""" + +from enum import Enum +import logging + +from splitio.models import MatcherNotFoundException +from splitio.models.splits import _DEFAULT_CONDITIONS_TEMPLATE +from splitio.models.grammar import condition +from splitio.models.splits import Status + +_LOGGER = logging.getLogger(__name__) + +class SegmentType(Enum): + """Segment type.""" + + STANDARD = "standard" + RULE_BASED = "rule-based" + +class RuleBasedSegment(object): + """RuleBasedSegment object class.""" + + def __init__(self, name, traffic_type_name, change_number, status, conditions, excluded): + """ + Class constructor. + + :param name: Segment name. + :type name: str + :param traffic_type_name: traffic type name. + :type traffic_type_name: str + :param change_number: change number. + :type change_number: str + :param status: status. + :type status: str + :param conditions: List of conditions belonging to the segment. + :type conditions: List + :param excluded: excluded objects. + :type excluded: Excluded + """ + self._name = name + self._traffic_type_name = traffic_type_name + self._change_number = change_number + self._conditions = conditions + self._excluded = excluded + try: + self._status = Status(status) + except ValueError: + self._status = Status.ARCHIVED + + @property + def name(self): + """Return segment name.""" + return self._name + + @property + def traffic_type_name(self): + """Return traffic type name.""" + return self._traffic_type_name + + @property + def change_number(self): + """Return change number.""" + return self._change_number + + @property + def status(self): + """Return status.""" + return self._status + + @property + def conditions(self): + """Return conditions.""" + return self._conditions + + @property + def excluded(self): + """Return excluded.""" + return self._excluded + + def to_json(self): + """Return a JSON representation of this rule based segment.""" + return { + 'changeNumber': self.change_number, + 'trafficTypeName': self.traffic_type_name, + 'name': self.name, + 'status': self.status.value, + 'conditions': [c.to_json() for c in self.conditions], + 'excluded': self.excluded.to_json() + } + + def get_condition_segment_names(self): + segments = set() + for condition in self._conditions: + for matcher in condition.matchers: + if matcher._matcher_type == 'IN_SEGMENT': + segments.add(matcher.to_json()['userDefinedSegmentMatcherData']['segmentName']) + return segments + +def from_raw(raw_rule_based_segment): + """ + Parse a Rule based segment from a JSON portion of splitChanges. + + :param raw_rule_based_segment: JSON object extracted from a splitChange's response + :type raw_rule_based_segment: dict + + :return: A parsed RuleBasedSegment object capable of performing evaluations. + :rtype: RuleBasedSegment + """ + try: + conditions = [condition.from_raw(c) for c in raw_rule_based_segment['conditions']] + except MatcherNotFoundException as e: + _LOGGER.error(str(e)) + _LOGGER.debug("Using default conditions template for feature flag: %s", raw_rule_based_segment['name']) + conditions = [condition.from_raw(_DEFAULT_CONDITIONS_TEMPLATE)] + + if raw_rule_based_segment.get('excluded') == None: + raw_rule_based_segment['excluded'] = {'keys': [], 'segments': []} + + if raw_rule_based_segment['excluded'].get('keys') == None: + raw_rule_based_segment['excluded']['keys'] = [] + + if raw_rule_based_segment['excluded'].get('segments') == None: + raw_rule_based_segment['excluded']['segments'] = [] + + return RuleBasedSegment( + raw_rule_based_segment['name'], + raw_rule_based_segment['trafficTypeName'], + raw_rule_based_segment['changeNumber'], + raw_rule_based_segment['status'], + conditions, + Excluded(raw_rule_based_segment['excluded']['keys'], raw_rule_based_segment['excluded']['segments']) + ) + +class Excluded(object): + + def __init__(self, keys, segments): + """ + Class constructor. + + :param keys: List of excluded keys in a rule based segment. + :type keys: List + :param segments: List of excluded segments in a rule based segment. + :type segments: List + """ + self._keys = keys + self._segments = [ExcludedSegment(segment['name'], segment['type']) for segment in segments] + + def get_excluded_keys(self): + """Return excluded keys.""" + return self._keys + + def get_excluded_segments(self): + """Return excluded segments""" + return self._segments + + def get_excluded_standard_segments(self): + """Return excluded segments""" + to_return = [] + for segment in self._segments: + if segment.type == SegmentType.STANDARD: + to_return.append(segment.name) + return to_return + + def to_json(self): + """Return a JSON representation of this object.""" + return { + 'keys': self._keys, + 'segments': self._segments + } + +class ExcludedSegment(object): + + def __init__(self, name, type): + """ + Class constructor. + + :param name: rule based segment name + :type name: str + :param type: segment type + :type type: str + """ + self._name = name + try: + self._type = SegmentType(type) + except ValueError: + self._type = SegmentType.STANDARD + + @property + def name(self): + """Return name.""" + return self._name + + @property + def type(self): + """Return type.""" + return self._type diff --git a/splitio/models/splits.py b/splitio/models/splits.py index 5e0ab394..47e69284 100644 --- a/splitio/models/splits.py +++ b/splitio/models/splits.py @@ -1,15 +1,67 @@ """Splits module.""" from enum import Enum from collections import namedtuple +import logging +from splitio.models import MatcherNotFoundException from splitio.models.grammar import condition +_LOGGER = logging.getLogger(__name__) SplitView = namedtuple( 'SplitView', - ['name', 'traffic_type', 'killed', 'treatments', 'change_number', 'configs'] + ['name', 'traffic_type', 'killed', 'treatments', 'change_number', 'configs', 'default_treatment', 'sets', 'impressions_disabled', 'prerequisites'] ) +_DEFAULT_CONDITIONS_TEMPLATE = { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": None, + "matcherType": "ALL_KEYS", + "negate": False, + "userDefinedSegmentMatcherData": None, + "whitelistMatcherData": None, + "unaryNumericMatcherData": None, + "betweenMatcherData": None, + "dependencyMatcherData": None, + "booleanMatcherData": None, + "stringMatcherData": None + }] + }, + "partitions": [ + { + "treatment": "control", + "size": 100 + } + ], + "label": "targeting rule type unsupported by sdk" +} + +class Prerequisites(object): + """Prerequisites.""" + def __init__(self, feature_flag_name, treatments): + self._feature_flag_name = feature_flag_name + self._treatments = treatments + + @property + def feature_flag_name(self): + """Return featur eflag name.""" + return self._feature_flag_name + + @property + def treatments(self): + """Return treatments.""" + return self._treatments + + def to_json(self): + to_return = [] + for feature_flag_name in self._feature_flag_name: + to_return.append({"n": feature_flag_name, "ts": [treatment for treatment in self._treatments]}) + + return to_return class Status(Enum): """Split status.""" @@ -41,7 +93,10 @@ def __init__( # pylint: disable=too-many-arguments algo=None, traffic_allocation=None, traffic_allocation_seed=None, - configurations=None + configurations=None, + sets=None, + impressions_disabled=None, + prerequisites = None ): """ Class constructor. @@ -62,6 +117,12 @@ def __init__( # pylint: disable=too-many-arguments :type traffic_allocation: int :pram traffic_allocation_seed: Seed used to hash traffic allocation. :type traffic_allocation_seed: int + :pram sets: list of flag sets + :type sets: list + :pram impressions_disabled: track impressions flag + :type impressions_disabled: boolean + :pram prerequisites: prerequisites + :type prerequisites: List of Preqreuisites """ self._name = name self._seed = seed @@ -90,6 +151,9 @@ def __init__( # pylint: disable=too-many-arguments self._algo = HashAlgorithm.LEGACY self._configurations = configurations + self._sets = set(sets) if sets is not None else set() + self._impressions_disabled = impressions_disabled if impressions_disabled is not None else False + self._prerequisites = prerequisites if prerequisites is not None else [] @property def name(self): @@ -146,6 +210,21 @@ def traffic_allocation_seed(self): """Return the traffic allocation seed of the split.""" return self._traffic_allocation_seed + @property + def sets(self): + """Return the flag sets of the split.""" + return self._sets + + @property + def impressions_disabled(self): + """Return impressions_disabled of the split.""" + return self._impressions_disabled + + @property + def prerequisites(self): + """Return prerequisites of the split.""" + return self._prerequisites + def get_configurations_for(self, treatment): """Return the mapping of treatments to configurations.""" return self._configurations.get(treatment) if self._configurations else None @@ -173,7 +252,10 @@ def to_json(self): 'defaultTreatment': self.default_treatment, 'algo': self.algo.value, 'conditions': [c.to_json() for c in self.conditions], - 'configurations': self._configurations + 'configurations': self._configurations, + 'sets': list(self._sets), + 'impressionsDisabled': self._impressions_disabled, + 'prerequisites': [prerequisite.to_json() for prerequisite in self._prerequisites] } def to_split_view(self): @@ -189,7 +271,11 @@ def to_split_view(self): self.killed, list(set(part.treatment for cond in self.conditions for part in cond.partitions)), self.change_number, - self._configurations if self._configurations is not None else {} + self._configurations if self._configurations is not None else {}, + self._default_treatment, + list(self._sets) if self._sets is not None else [], + self._impressions_disabled, + self._prerequisites ) def local_kill(self, default_treatment, change_number): @@ -226,6 +312,12 @@ def from_raw(raw_split): :return: A parsed Split object capable of performing evaluations. :rtype: Split """ + try: + conditions = [condition.from_raw(c) for c in raw_split['conditions']] + except MatcherNotFoundException as e: + _LOGGER.error(str(e)) + _LOGGER.debug("Using default conditions template for feature flag: %s", raw_split['name']) + conditions = [condition.from_raw(_DEFAULT_CONDITIONS_TEMPLATE)] return Split( raw_split['name'], raw_split['seed'], @@ -234,9 +326,19 @@ def from_raw(raw_split): raw_split['trafficTypeName'], raw_split['status'], raw_split['changeNumber'], - [condition.from_raw(c) for c in raw_split['conditions']], + conditions, raw_split.get('algo'), traffic_allocation=raw_split.get('trafficAllocation'), traffic_allocation_seed=raw_split.get('trafficAllocationSeed'), - configurations=raw_split.get('configurations') + configurations=raw_split.get('configurations'), + sets=set(raw_split.get('sets')) if raw_split.get('sets') is not None else [], + impressions_disabled=raw_split.get('impressionsDisabled') if raw_split.get('impressionsDisabled') is not None else False, + prerequisites=from_raw_prerequisites(raw_split.get('prerequisites')) if raw_split.get('prerequisites') is not None else [] ) + +def from_raw_prerequisites(raw_prerequisites): + to_return = [] + for prerequisite in raw_prerequisites: + to_return.append(Prerequisites(prerequisite['n'], prerequisite['ts'])) + + return to_return \ No newline at end of file diff --git a/splitio/models/telemetry.py b/splitio/models/telemetry.py index e4739328..c9715da4 100644 --- a/splitio/models/telemetry.py +++ b/splitio/models/telemetry.py @@ -1,6 +1,12 @@ """SDK Telemetry helpers.""" from bisect import bisect_left +import threading +import os +from enum import Enum +import abc +from splitio.engine.impressions import ImpressionsMode +from splitio.optional.loaders import asyncio BUCKETS = ( 1000, 1500, 2250, 3375, 5063, @@ -9,8 +15,132 @@ 437894, 656841, 985261, 1477892, 2216838, 3325257, 4987885, 7481828 ) + MAX_LATENCY = 7481828 +MAX_LATENCY_BUCKET_COUNT = 23 +MAX_STREAMING_EVENTS = 20 +MAX_TAGS = 10 + +class CounterConstants(Enum): + """Impressions and events counters constants""" + IMPRESSIONS_QUEUED = 'impressionsQueued' + IMPRESSIONS_DEDUPED = 'impressionsDeduped' + IMPRESSIONS_DROPPED = 'impressionsDropped' + EVENTS_QUEUED = 'eventsQueued' + EVENTS_DROPPED = 'eventsDropped' + +class _ConfigParams(Enum): + """Config parameters constants""" + SPLITS_REFRESH_RATE = 'featuresRefreshRate' + SEGMENTS_REFRESH_RATE = 'segmentsRefreshRate' + IMPRESSIONS_REFRESH_RATE = 'impressionsRefreshRate' + EVENTS_REFRESH_RATE = 'eventsPushRate' + TELEMETRY_REFRESH_RATE = 'metricsRefreshRate' + OPERATION_MODE = 'operationMode' + STORAGE_TYPE = 'storageType' + STREAMING_ENABLED = 'streamingEnabled' + IMPRESSIONS_QUEUE_SIZE = 'impressionsQueueSize' + EVENTS_QUEUE_SIZE = 'eventsQueueSize' + IMPRESSIONS_MODE = 'impressionsMode' + IMPRESSIONS_LISTENER = 'impressionListener' + +class _ExtraConfig(Enum): + """Extra config constants""" + ACTIVE_FACTORY_COUNT = 'activeFactoryCount' + REDUNDANT_FACTORY_COUNT = 'redundantFactoryCount' + BLOCK_UNTIL_READY_TIMEOUT = 'blockUntilReadyTimeout' + NOT_READY = 'notReady' + TIME_UNTIL_READY = 'timeUntilReady' + REFRESH_RATE = 'refreshRate' + HTTP_PROXY = 'httpProxy' + HTTPS_PROXY_ENV = 'HTTPS_PROXY' + +class _ApiURLs(Enum): + """Api URL constants""" + SDK_URL = 'sdk_url' + EVENTS_URL = 'events_url' + AUTH_URL = 'auth_url' + STREAMING_URL = 'streaming_url' + TELEMETRY_URL = 'telemetry_url' + URL_OVERRIDE = 'urlOverride' + +class HTTPExceptionsAndLatencies(Enum): + """Sync exceptions and latencies constants""" + HTTP_ERRORS = 'httpErrors' + HTTP_LATENCIES = 'httpLatencies' + SPLIT = 'split' + SEGMENT = 'segment' + IMPRESSION = 'impression' + IMPRESSION_COUNT = 'impressionCount' + EVENT = 'event' + TELEMETRY = 'telemetry' + TOKEN = 'token' + +class MethodExceptionsAndLatencies(Enum): + """Method exceptions and latencies constants""" + METHOD_LATENCIES = 'methodLatencies' + METHOD_EXCEPTIONS = 'methodExceptions' + TREATMENT = 'treatment' + TREATMENTS = 'treatments' + TREATMENT_WITH_CONFIG = 'treatment_with_config' + TREATMENTS_WITH_CONFIG = 'treatments_with_config' + TREATMENTS_BY_FLAG_SET = 'treatments_by_flag_set' + TREATMENTS_BY_FLAG_SETS = 'treatments_by_flag_sets' + TREATMENTS_WITH_CONFIG_BY_FLAG_SET = 'treatments_with_config_by_flag_set' + TREATMENTS_WITH_CONFIG_BY_FLAG_SETS = 'treatments_with_config_by_flag_sets' + TRACK = 'track' + +class _LastSynchronizationConstants(Enum): + """Last sync constants""" + LAST_SYNCHRONIZATIONS = 'lastSynchronizations' + +class SSEStreamingStatus(Enum): + """SSE streaming status enums""" + ENABLED = 0 + DISABLED = 1 + PAUSED = 2 + +class SSEConnectionError(Enum): + """SSE Connection Error enums""" + REQUESTED = 0 + NON_REQUESTED = 1 + +class SSESyncMode(Enum): + """SSE sync mode enums""" + STREAMING = 0 + POLLING = 1 +class _StreamingEventsConstant(Enum): + """Storage types constant""" + STREAMING_EVENTS = 'streamingEvents' + +class StreamingEventTypes(Enum): + """Streaming event types constants""" + CONNECTION_ESTABLISHED = 0 + OCCUPANCY_PRI = 10 + OCCUPANCY_SEC = 20 + STREAMING_STATUS = 30 + SSE_CONNECTION_ERROR = 40 + TOKEN_REFRESH = 50 + ABLY_ERROR = 60 + SYNC_MODE_UPDATE = 70 + +class StorageType(Enum): + """Storage types constants""" + MEMORY = 'memory' + REDIS = 'redis' + PLUGGABLE = 'pluggable' + +class OperationMode(Enum): + """Storage modes constants""" + STANDALONE = 'standalone' + CONSUMER = 'consumer' + PARTIAL_CONSUMER = 'partial_consumer' + +class UpdateFromSSE(Enum): + """Update from sse constants""" + SPLIT_UPDATE = 'sp' + RBS_UPDATE = 'rbs' def get_latency_bucket_index(micros): """ @@ -25,3 +155,1801 @@ def get_latency_bucket_index(micros): return len(BUCKETS) - 1 return bisect_left(BUCKETS, micros) + +class MethodLatenciesBase(object, metaclass=abc.ABCMeta): + """ + Method Latency base class + + """ + def _reset_all(self): + """Reset variables""" + self._treatment = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatments = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatment_with_config = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatments_with_config = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatments_by_flag_set = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatments_by_flag_sets = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatments_with_config_by_flag_set = [0] * MAX_LATENCY_BUCKET_COUNT + self._treatments_with_config_by_flag_sets = [0] * MAX_LATENCY_BUCKET_COUNT + self._track = [0] * MAX_LATENCY_BUCKET_COUNT + + @abc.abstractmethod + def add_latency(self, method, latency): + """ + Add Latency method + """ + + @abc.abstractmethod + def pop_all(self): + """ + Pop all latencies + """ + +class MethodLatencies(MethodLatenciesBase): + """ + Method Latency class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def add_latency(self, method, latency): + """ + Add Latency method + + :param method: passed method name + :type method: str + :param latency: amount of latency in microseconds + :type latency: int + """ + latency_bucket = get_latency_bucket_index(latency) + with self._lock: + if method == MethodExceptionsAndLatencies.TREATMENT: + self._treatment[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS: + self._treatments[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG: + self._treatment_with_config[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG: + self._treatments_with_config[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET: + self._treatments_by_flag_set[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS: + self._treatments_by_flag_sets[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET: + self._treatments_with_config_by_flag_set[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS: + self._treatments_with_config_by_flag_sets[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TRACK: + self._track[latency_bucket] += 1 + else: + return + + def pop_all(self): + """ + Pop all latencies + + :return: Dictonary of latencies + :rtype: dict + """ + with self._lock: + latencies = {MethodExceptionsAndLatencies.METHOD_LATENCIES.value: { + MethodExceptionsAndLatencies.TREATMENT.value: self._treatment, + MethodExceptionsAndLatencies.TREATMENTS.value: self._treatments, + MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG.value: self._treatment_with_config, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG.value: self._treatments_with_config, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET.value: self._treatments_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS.value: self._treatments_by_flag_sets, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET.value: self._treatments_with_config_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS.value: self._treatments_with_config_by_flag_sets, + MethodExceptionsAndLatencies.TRACK.value: self._track} + } + self._reset_all() + return latencies + + +class MethodLatenciesAsync(MethodLatenciesBase): + """ + Method async Latency class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def add_latency(self, method, latency): + """ + Add Latency method + + :param method: passed method name + :type method: str + :param latency: amount of latency in microseconds + :type latency: int + """ + latency_bucket = get_latency_bucket_index(latency) + async with self._lock: + if method == MethodExceptionsAndLatencies.TREATMENT: + self._treatment[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS: + self._treatments[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG: + self._treatment_with_config[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG: + self._treatments_with_config[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET: + self._treatments_by_flag_set[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS: + self._treatments_by_flag_sets[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET: + self._treatments_with_config_by_flag_set[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS: + self._treatments_with_config_by_flag_sets[latency_bucket] += 1 + elif method == MethodExceptionsAndLatencies.TRACK: + self._track[latency_bucket] += 1 + else: + return + + async def pop_all(self): + """ + Pop all latencies + + :return: Dictonary of latencies + :rtype: dict + """ + async with self._lock: + latencies = {MethodExceptionsAndLatencies.METHOD_LATENCIES.value: { + MethodExceptionsAndLatencies.TREATMENT.value: self._treatment, + MethodExceptionsAndLatencies.TREATMENTS.value: self._treatments, + MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG.value: self._treatment_with_config, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG.value: self._treatments_with_config, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET.value: self._treatments_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS.value: self._treatments_by_flag_sets, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET.value: self._treatments_with_config_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS.value: self._treatments_with_config_by_flag_sets, + MethodExceptionsAndLatencies.TRACK.value: self._track} + } + self._reset_all() + return latencies + + +class HTTPLatenciesBase(object, metaclass=abc.ABCMeta): + """ + HTTP Latency class + + """ + def _reset_all(self): + """Reset variables""" + self._split = [0] * MAX_LATENCY_BUCKET_COUNT + self._segment = [0] * MAX_LATENCY_BUCKET_COUNT + self._impression = [0] * MAX_LATENCY_BUCKET_COUNT + self._impression_count = [0] * MAX_LATENCY_BUCKET_COUNT + self._event = [0] * MAX_LATENCY_BUCKET_COUNT + self._telemetry = [0] * MAX_LATENCY_BUCKET_COUNT + self._token = [0] * MAX_LATENCY_BUCKET_COUNT + + @abc.abstractmethod + def add_latency(self, resource, latency): + """ + Add Latency method + """ + + @abc.abstractmethod + def pop_all(self): + """ + Pop all latencies + """ + + +class HTTPLatencies(HTTPLatenciesBase): + """ + HTTP Latency class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def add_latency(self, resource, latency): + """ + Add Latency method + + :param resource: passed resource name + :type resource: str + :param latency: amount of latency in microseconds + :type latency: int + """ + latency_bucket = get_latency_bucket_index(latency) + with self._lock: + if resource == HTTPExceptionsAndLatencies.SPLIT: + self._split[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.SEGMENT: + self._segment[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION: + self._impression[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + self._impression_count[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.EVENT: + self._event[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.TELEMETRY: + self._telemetry[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.TOKEN: + self._token[latency_bucket] += 1 + else: + return + + def pop_all(self): + """ + Pop all latencies + + :return: Dictonary of latencies + :rtype: dict + """ + with self._lock: + latencies = {HTTPExceptionsAndLatencies.HTTP_LATENCIES.value: {HTTPExceptionsAndLatencies.SPLIT.value: self._split, HTTPExceptionsAndLatencies.SEGMENT.value: self._segment, HTTPExceptionsAndLatencies.IMPRESSION.value: self._impression, + HTTPExceptionsAndLatencies.IMPRESSION_COUNT.value: self._impression_count, HTTPExceptionsAndLatencies.EVENT.value: self._event, + HTTPExceptionsAndLatencies.TELEMETRY.value: self._telemetry, HTTPExceptionsAndLatencies.TOKEN.value: self._token} + } + self._reset_all() + return latencies + + +class HTTPLatenciesAsync(HTTPLatenciesBase): + """ + HTTP Latency async class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def add_latency(self, resource, latency): + """ + Add Latency method + + :param resource: passed resource name + :type resource: str + :param latency: amount of latency in microseconds + :type latency: int + """ + latency_bucket = get_latency_bucket_index(latency) + async with self._lock: + if resource == HTTPExceptionsAndLatencies.SPLIT: + self._split[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.SEGMENT: + self._segment[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION: + self._impression[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + self._impression_count[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.EVENT: + self._event[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.TELEMETRY: + self._telemetry[latency_bucket] += 1 + elif resource == HTTPExceptionsAndLatencies.TOKEN: + self._token[latency_bucket] += 1 + else: + return + + async def pop_all(self): + """ + Pop all latencies + + :return: Dictonary of latencies + :rtype: dict + """ + async with self._lock: + latencies = {HTTPExceptionsAndLatencies.HTTP_LATENCIES.value: {HTTPExceptionsAndLatencies.SPLIT.value: self._split, HTTPExceptionsAndLatencies.SEGMENT.value: self._segment, HTTPExceptionsAndLatencies.IMPRESSION.value: self._impression, + HTTPExceptionsAndLatencies.IMPRESSION_COUNT.value: self._impression_count, HTTPExceptionsAndLatencies.EVENT.value: self._event, + HTTPExceptionsAndLatencies.TELEMETRY.value: self._telemetry, HTTPExceptionsAndLatencies.TOKEN.value: self._token} + } + self._reset_all() + return latencies + + +class MethodExceptionsBase(object, metaclass=abc.ABCMeta): + """ + Method exceptions base class + + """ + def _reset_all(self): + """Reset variables""" + self._treatment = 0 + self._treatments = 0 + self._treatment_with_config = 0 + self._treatments_with_config = 0 + self._treatments_by_flag_set = 0 + self._treatments_by_flag_sets = 0 + self._treatments_with_config_by_flag_set = 0 + self._treatments_with_config_by_flag_sets = 0 + self._track = 0 + + @abc.abstractmethod + def add_exception(self, method): + """ + Add exceptions method + """ + + @abc.abstractmethod + def pop_all(self): + """ + Pop all exceptions + """ + + +class MethodExceptions(MethodExceptionsBase): + """ + Method exceptions class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def add_exception(self, method): + """ + Add exceptions method + + :param method: passed method name + :type method: str + """ + with self._lock: + if method == MethodExceptionsAndLatencies.TREATMENT: + self._treatment += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS: + self._treatments += 1 + elif method == MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG: + self._treatment_with_config += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG: + self._treatments_with_config += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET: + self._treatments_by_flag_set += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS: + self._treatments_by_flag_sets += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET: + self._treatments_with_config_by_flag_set += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS: + self._treatments_with_config_by_flag_sets += 1 + elif method == MethodExceptionsAndLatencies.TRACK: + self._track += 1 + else: + return + + def pop_all(self): + """ + Pop all exceptions + + :return: Dictonary of exceptions + :rtype: dict + """ + with self._lock: + exceptions = { + MethodExceptionsAndLatencies.METHOD_EXCEPTIONS.value: { + MethodExceptionsAndLatencies.TREATMENT.value: self._treatment, + MethodExceptionsAndLatencies.TREATMENTS.value: self._treatments, + MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG.value: self._treatment_with_config, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG.value: self._treatments_with_config, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET.value: self._treatments_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS.value: self._treatments_by_flag_sets, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET.value: self._treatments_with_config_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS.value: self._treatments_with_config_by_flag_sets, + MethodExceptionsAndLatencies.TRACK.value: self._track} + } + self._reset_all() + return exceptions + + +class MethodExceptionsAsync(MethodExceptionsBase): + """ + Method async exceptions class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def add_exception(self, method): + """ + Add exceptions method + + :param method: passed method name + :type method: str + """ + async with self._lock: + if method == MethodExceptionsAndLatencies.TREATMENT: + self._treatment += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS: + self._treatments += 1 + elif method == MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG: + self._treatment_with_config += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG: + self._treatments_with_config += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET: + self._treatments_by_flag_set += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS: + self._treatments_by_flag_sets += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET: + self._treatments_with_config_by_flag_set += 1 + elif method == MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS: + self._treatments_with_config_by_flag_sets += 1 + elif method == MethodExceptionsAndLatencies.TRACK: + self._track += 1 + else: + return + + async def pop_all(self): + """ + Pop all exceptions + + :return: Dictonary of exceptions + :rtype: dict + """ + async with self._lock: + exceptions = { + MethodExceptionsAndLatencies.METHOD_EXCEPTIONS.value: { + MethodExceptionsAndLatencies.TREATMENT.value: self._treatment, + MethodExceptionsAndLatencies.TREATMENTS.value: self._treatments, + MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG.value: self._treatment_with_config, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG.value: self._treatments_with_config, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET.value: self._treatments_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS.value: self._treatments_by_flag_sets, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET.value: self._treatments_with_config_by_flag_set, + MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS.value: self._treatments_with_config_by_flag_sets, + MethodExceptionsAndLatencies.TRACK.value: self._track} + } + self._reset_all() + return exceptions + + +class LastSynchronizationBase(object, metaclass=abc.ABCMeta): + """ + Last Synchronization info base class + + """ + def _reset_all(self): + """Reset variables""" + self._split = 0 + self._segment = 0 + self._impression = 0 + self._impression_count = 0 + self._event = 0 + self._telemetry = 0 + self._token = 0 + + @abc.abstractmethod + def add_latency(self, resource, sync_time): + """ + Add Latency method + """ + + @abc.abstractmethod + def get_all(self): + """ + get all exceptions + """ + +class LastSynchronization(LastSynchronizationBase): + """ + Last Synchronization info class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def add_latency(self, resource, sync_time): + """ + Add Latency method + + :param resource: passed resource name + :type resource: str + :param sync_time: amount of last sync time + :type sync_time: int + """ + with self._lock: + if resource == HTTPExceptionsAndLatencies.SPLIT: + self._split = sync_time + elif resource == HTTPExceptionsAndLatencies.SEGMENT: + self._segment = sync_time + elif resource == HTTPExceptionsAndLatencies.IMPRESSION: + self._impression = sync_time + elif resource == HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + self._impression_count = sync_time + elif resource == HTTPExceptionsAndLatencies.EVENT: + self._event = sync_time + elif resource == HTTPExceptionsAndLatencies.TELEMETRY: + self._telemetry = sync_time + elif resource == HTTPExceptionsAndLatencies.TOKEN: + self._token = sync_time + else: + return + + def get_all(self): + """ + get all exceptions + + :return: Dictonary of latencies + :rtype: dict + """ + with self._lock: + return { + _LastSynchronizationConstants.LAST_SYNCHRONIZATIONS.value: { + HTTPExceptionsAndLatencies.SPLIT.value: self._split, + HTTPExceptionsAndLatencies.SEGMENT.value: self._segment, + HTTPExceptionsAndLatencies.IMPRESSION.value: self._impression, + HTTPExceptionsAndLatencies.IMPRESSION_COUNT.value: self._impression_count, + HTTPExceptionsAndLatencies.EVENT.value: self._event, + HTTPExceptionsAndLatencies.TELEMETRY.value: self._telemetry, + HTTPExceptionsAndLatencies.TOKEN.value: self._token} + } + +class LastSynchronizationAsync(LastSynchronizationBase): + """ + Last Synchronization async info class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def add_latency(self, resource, sync_time): + """ + Add Latency method + + :param resource: passed resource name + :type resource: str + :param sync_time: amount of last sync time + :type sync_time: int + """ + async with self._lock: + if resource == HTTPExceptionsAndLatencies.SPLIT: + self._split = sync_time + elif resource == HTTPExceptionsAndLatencies.SEGMENT: + self._segment = sync_time + elif resource == HTTPExceptionsAndLatencies.IMPRESSION: + self._impression = sync_time + elif resource == HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + self._impression_count = sync_time + elif resource == HTTPExceptionsAndLatencies.EVENT: + self._event = sync_time + elif resource == HTTPExceptionsAndLatencies.TELEMETRY: + self._telemetry = sync_time + elif resource == HTTPExceptionsAndLatencies.TOKEN: + self._token = sync_time + else: + return + + async def get_all(self): + """ + get all exceptions + + :return: Dictonary of latencies + :rtype: dict + """ + async with self._lock: + return { + _LastSynchronizationConstants.LAST_SYNCHRONIZATIONS.value: { + HTTPExceptionsAndLatencies.SPLIT.value: self._split, + HTTPExceptionsAndLatencies.SEGMENT.value: self._segment, + HTTPExceptionsAndLatencies.IMPRESSION.value: self._impression, + HTTPExceptionsAndLatencies.IMPRESSION_COUNT.value: self._impression_count, + HTTPExceptionsAndLatencies.EVENT.value: self._event, + HTTPExceptionsAndLatencies.TELEMETRY.value: self._telemetry, + HTTPExceptionsAndLatencies.TOKEN.value: self._token} + } + + +class HTTPErrorsBase(object, metaclass=abc.ABCMeta): + """ + Http errors base class + + """ + def _reset_all(self): + """Reset variables""" + self._split = {} + self._segment = {} + self._impression = {} + self._impression_count = {} + self._event = {} + self._telemetry = {} + self._token = {} + + @abc.abstractmethod + def add_error(self, resource, status): + """ + Add Latency method + """ + + @abc.abstractmethod + def pop_all(self): + """ + Pop all errors + """ + + +class HTTPErrors(HTTPErrorsBase): + """ + Http errors class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def add_error(self, resource, status): + """ + Add Latency method + + :param resource: passed resource name + :type resource: str + :param status: http error code + :type status: str + """ + status = str(status) + with self._lock: + if resource == HTTPExceptionsAndLatencies.SPLIT: + if status not in self._split: + self._split[status] = 0 + self._split[status] += 1 + elif resource == HTTPExceptionsAndLatencies.SEGMENT: + if status not in self._segment: + self._segment[status] = 0 + self._segment[status] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION: + if status not in self._impression: + self._impression[status] = 0 + self._impression[status] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + if status not in self._impression_count: + self._impression_count[status] = 0 + self._impression_count[status] += 1 + elif resource == HTTPExceptionsAndLatencies.EVENT: + if status not in self._event: + self._event[status] = 0 + self._event[status] += 1 + elif resource == HTTPExceptionsAndLatencies.TELEMETRY: + if status not in self._telemetry: + self._telemetry[status] = 0 + self._telemetry[status] += 1 + elif resource == HTTPExceptionsAndLatencies.TOKEN: + if status not in self._token: + self._token[status] = 0 + self._token[status] += 1 + else: + return + + def pop_all(self): + """ + Pop all errors + + :return: Dictonary of exceptions + :rtype: dict + """ + with self._lock: + http_errors = { + HTTPExceptionsAndLatencies.HTTP_ERRORS.value: { + HTTPExceptionsAndLatencies.SPLIT.value: self._split, + HTTPExceptionsAndLatencies.SEGMENT.value: self._segment, + HTTPExceptionsAndLatencies.IMPRESSION.value: self._impression, + HTTPExceptionsAndLatencies.IMPRESSION_COUNT.value: self._impression_count, HTTPExceptionsAndLatencies.EVENT.value: self._event, + HTTPExceptionsAndLatencies.TELEMETRY.value: self._telemetry, HTTPExceptionsAndLatencies.TOKEN.value: self._token + } + } + self._reset_all() + return http_errors + + +class HTTPErrorsAsync(HTTPErrorsBase): + """ + Http error async class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def add_error(self, resource, status): + """ + Add Latency method + + :param resource: passed resource name + :type resource: str + :param status: http error code + :type status: str + """ + status = str(status) + async with self._lock: + if resource == HTTPExceptionsAndLatencies.SPLIT: + if status not in self._split: + self._split[status] = 0 + self._split[status] += 1 + elif resource == HTTPExceptionsAndLatencies.SEGMENT: + if status not in self._segment: + self._segment[status] = 0 + self._segment[status] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION: + if status not in self._impression: + self._impression[status] = 0 + self._impression[status] += 1 + elif resource == HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + if status not in self._impression_count: + self._impression_count[status] = 0 + self._impression_count[status] += 1 + elif resource == HTTPExceptionsAndLatencies.EVENT: + if status not in self._event: + self._event[status] = 0 + self._event[status] += 1 + elif resource == HTTPExceptionsAndLatencies.TELEMETRY: + if status not in self._telemetry: + self._telemetry[status] = 0 + self._telemetry[status] += 1 + elif resource == HTTPExceptionsAndLatencies.TOKEN: + if status not in self._token: + self._token[status] = 0 + self._token[status] += 1 + else: + return + + async def pop_all(self): + """ + Pop all errors + + :return: Dictonary of exceptions + :rtype: dict + """ + async with self._lock: + http_errors = { + HTTPExceptionsAndLatencies.HTTP_ERRORS.value: { + HTTPExceptionsAndLatencies.SPLIT.value: self._split, + HTTPExceptionsAndLatencies.SEGMENT.value: self._segment, + HTTPExceptionsAndLatencies.IMPRESSION.value: self._impression, + HTTPExceptionsAndLatencies.IMPRESSION_COUNT.value: self._impression_count, HTTPExceptionsAndLatencies.EVENT.value: self._event, + HTTPExceptionsAndLatencies.TELEMETRY.value: self._telemetry, HTTPExceptionsAndLatencies.TOKEN.value: self._token + } + } + self._reset_all() + return http_errors + + +class TelemetryCountersBase(object, metaclass=abc.ABCMeta): + """ + Counters base class + + """ + def _reset_all(self): + """Reset variables""" + self._impressions_queued = 0 + self._impressions_deduped = 0 + self._impressions_dropped = 0 + self._events_queued = 0 + self._events_dropped = 0 + self._auth_rejections = 0 + self._token_refreshes = 0 + self._session_length = 0 + self._update_from_sse = {} + + @abc.abstractmethod + def record_impressions_value(self, resource, value): + """ + Append to the resource value + """ + + @abc.abstractmethod + def record_events_value(self, resource, value): + """ + Append to the resource value + """ + + @abc.abstractmethod + def record_auth_rejections(self): + """ + Increament the auth rejection resource by one. + """ + + @abc.abstractmethod + def record_token_refreshes(self): + """ + Increament the token refreshes resource by one. + """ + + @abc.abstractmethod + def record_session_length(self, session): + """ + Set the session length value + """ + + @abc.abstractmethod + def get_counter_stats(self, resource): + """ + Get resource counter value + """ + + @abc.abstractmethod + def get_session_length(self): + """ + Get session length + """ + + @abc.abstractmethod + def pop_auth_rejections(self): + """ + Pop auth rejections + """ + + @abc.abstractmethod + def pop_token_refreshes(self): + """ + Pop token refreshes + """ + + +class TelemetryCounters(TelemetryCountersBase): + """ + Counters class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def record_impressions_value(self, resource, value): + """ + Append to the resource value + + :param resource: passed resource name + :type resource: str + :param value: value to be appended + :type value: int + """ + with self._lock: + if resource == CounterConstants.IMPRESSIONS_QUEUED: + self._impressions_queued += value + elif resource == CounterConstants.IMPRESSIONS_DEDUPED: + self._impressions_deduped += value + elif resource == CounterConstants.IMPRESSIONS_DROPPED: + self._impressions_dropped += value + else: + return + + def record_events_value(self, resource, value): + """ + Append to the resource value + + :param resource: passed resource name + :type resource: str + :param value: value to be appended + :type value: int + """ + with self._lock: + if resource == CounterConstants.EVENTS_QUEUED: + self._events_queued += value + elif resource == CounterConstants.EVENTS_DROPPED: + self._events_dropped += value + else: + return + + def record_update_from_sse(self, event): + """ + Increment the update from sse resource by one. + """ + with self._lock: + if event.value not in self._update_from_sse: + self._update_from_sse[event.value] = 0 + self._update_from_sse[event.value] += 1 + + def record_auth_rejections(self): + """ + Increment the auth rejection resource by one. + + """ + with self._lock: + self._auth_rejections += 1 + + def record_token_refreshes(self): + """ + Increment the token refreshes resource by one. + + """ + with self._lock: + self._token_refreshes += 1 + + def pop_update_from_sse(self, event): + """ + Pop update from sse + :return: update from sse value + :rtype: int + """ + with self._lock: + if self._update_from_sse.get(event.value) is None: + return 0 + + update_from_sse = self._update_from_sse[event.value] + self._update_from_sse[event.value] = 0 + return update_from_sse + + def record_session_length(self, session): + """ + Set the session length value + + :param session: value to be set + :type session: int + """ + with self._lock: + self._session_length = session + + def get_counter_stats(self, resource): + """ + Get resource counter value + + :param resource: passed resource name + :type resource: str + + :return: resource value + :rtype: int + """ + + with self._lock: + if resource == CounterConstants.IMPRESSIONS_QUEUED: + return self._impressions_queued + + elif resource == CounterConstants.IMPRESSIONS_DEDUPED: + return self._impressions_deduped + + elif resource == CounterConstants.IMPRESSIONS_DROPPED: + return self._impressions_dropped + + elif resource == CounterConstants.EVENTS_QUEUED: + return self._events_queued + + elif resource == CounterConstants.EVENTS_DROPPED: + return self._events_dropped + + else: + return 0 + + def get_session_length(self): + """ + Get session length + + :return: session length value + :rtype: int + """ + with self._lock: + return self._session_length + + def pop_auth_rejections(self): + """ + Pop auth rejections + + :return: auth rejections value + :rtype: int + """ + with self._lock: + auth_rejections = self._auth_rejections + self._auth_rejections = 0 + return auth_rejections + + def pop_token_refreshes(self): + """ + Pop token refreshes + + :return: token refreshes value + :rtype: int + """ + with self._lock: + token_refreshes = self._token_refreshes + self._token_refreshes = 0 + return token_refreshes + +class TelemetryCountersAsync(TelemetryCountersBase): + """ + Counters async class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def record_impressions_value(self, resource, value): + """ + Append to the resource value + + :param resource: passed resource name + :type resource: str + :param value: value to be appended + :type value: int + """ + async with self._lock: + if resource == CounterConstants.IMPRESSIONS_QUEUED: + self._impressions_queued += value + elif resource == CounterConstants.IMPRESSIONS_DEDUPED: + self._impressions_deduped += value + elif resource == CounterConstants.IMPRESSIONS_DROPPED: + self._impressions_dropped += value + else: + return + + async def record_events_value(self, resource, value): + """ + Append to the resource value + + :param resource: passed resource name + :type resource: str + :param value: value to be appended + :type value: int + """ + async with self._lock: + if resource == CounterConstants.EVENTS_QUEUED: + self._events_queued += value + elif resource == CounterConstants.EVENTS_DROPPED: + self._events_dropped += value + else: + return + + async def record_update_from_sse(self, event): + """ + Increment the update from sse resource by one. + """ + async with self._lock: + if event.value not in self._update_from_sse: + self._update_from_sse[event.value] = 0 + self._update_from_sse[event.value] += 1 + + async def record_auth_rejections(self): + """ + Increment the auth rejection resource by one. + + """ + async with self._lock: + self._auth_rejections += 1 + + async def record_token_refreshes(self): + """ + Increment the token refreshes resource by one. + + """ + async with self._lock: + self._token_refreshes += 1 + + async def pop_update_from_sse(self, event): + """ + Pop update from sse + :return: update from sse value + :rtype: int + """ + async with self._lock: + if self._update_from_sse.get(event.value) is None: + return 0 + + update_from_sse = self._update_from_sse[event.value] + self._update_from_sse[event.value] = 0 + return update_from_sse + + async def record_session_length(self, session): + """ + Set the session length value + + :param session: value to be set + :type session: int + """ + async with self._lock: + self._session_length = session + + async def get_counter_stats(self, resource): + """ + Get resource counter value + + :param resource: passed resource name + :type resource: str + + :return: resource value + :rtype: int + """ + async with self._lock: + if resource == CounterConstants.IMPRESSIONS_QUEUED: + return self._impressions_queued + + elif resource == CounterConstants.IMPRESSIONS_DEDUPED: + return self._impressions_deduped + + elif resource == CounterConstants.IMPRESSIONS_DROPPED: + return self._impressions_dropped + + elif resource == CounterConstants.EVENTS_QUEUED: + return self._events_queued + + elif resource == CounterConstants.EVENTS_DROPPED: + return self._events_dropped + + else: + return 0 + + async def get_session_length(self): + """ + Get session length + + :return: session length value + :rtype: int + """ + async with self._lock: + return self._session_length + + async def pop_auth_rejections(self): + """ + Pop auth rejections + + :return: auth rejections value + :rtype: int + """ + async with self._lock: + auth_rejections = self._auth_rejections + self._auth_rejections = 0 + return auth_rejections + + async def pop_token_refreshes(self): + """ + Pop token refreshes + + :return: token refreshes value + :rtype: int + """ + async with self._lock: + token_refreshes = self._token_refreshes + self._token_refreshes = 0 + return token_refreshes + + +class StreamingEvent(object): + """ + Streaming event class + + """ + def __init__(self, streaming_event): + """ + Constructor + + :param streaming_event: Streaming event tuple: ('type', 'data', 'time') + :type streaming_event: dict + """ + self._type = streaming_event[0].value + self._data = streaming_event[1] + self._time = streaming_event[2] + + @property + def type(self): + """ + Get streaming event type + + :return: streaming event type + :rtype: str + """ + return self._type + + @property + def data(self): + """ + Get streaming event data + + :return: streaming event data + :rtype: str + """ + return self._data + + @property + def time(self): + """ + Get streaming event time + + :return: streaming event time + :rtype: int + """ + return self._time + +class StreamingEventsAsync(object): + """ + Streaming events async class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._streaming_events = [] + return self + + async def record_streaming_event(self, streaming_event): + """ + Record new streaming event + + :param streaming_event: Streaming event dict: + {'type': string, 'data': string, 'time': string} + :type streaming_event: dict + """ + if not StreamingEvent(streaming_event): + return + async with self._lock: + if len(self._streaming_events) < MAX_STREAMING_EVENTS: + self._streaming_events.append(StreamingEvent(streaming_event)) + + async def pop_streaming_events(self): + """ + Get and reset streaming events + + :return: streaming events dict + :rtype: dict + """ + async with self._lock: + streaming_events = self._streaming_events + self._streaming_events = [] + return {_StreamingEventsConstant.STREAMING_EVENTS.value: [ + {'e': streaming_event.type, 'd': streaming_event.data, + 't': streaming_event.time} for streaming_event in streaming_events]} + +class StreamingEvents(object): + """ + Streaming events class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._streaming_events = [] + + def record_streaming_event(self, streaming_event): + """ + Record new streaming event + + :param streaming_event: Streaming event dict: + {'type': string, 'data': string, 'time': string} + :type streaming_event: dict + """ + if not StreamingEvent(streaming_event): + return + with self._lock: + if len(self._streaming_events) < MAX_STREAMING_EVENTS: + self._streaming_events.append(StreamingEvent(streaming_event)) + + def pop_streaming_events(self): + """ + Get and reset streaming events + + :return: streaming events dict + :rtype: dict + """ + + with self._lock: + streaming_events = self._streaming_events + self._streaming_events = [] + return {_StreamingEventsConstant.STREAMING_EVENTS.value: [ + {'e': streaming_event.type, 'd': streaming_event.data, + 't': streaming_event.time} for streaming_event in streaming_events]} + + +class TelemetryConfigBase(object, metaclass=abc.ABCMeta): + """ + Telemetry init config base class + + """ + def _reset_all(self): + """Reset variables""" + self._block_until_ready_timeout = 0 + self._not_ready = 0 + self._time_until_ready = 0 + self._operation_mode = None + self._storage_type = None + self._streaming_enabled = None + self._refresh_rate = { + _ConfigParams.SPLITS_REFRESH_RATE.value: 0, + _ConfigParams.SEGMENTS_REFRESH_RATE.value: 0, + _ConfigParams.IMPRESSIONS_REFRESH_RATE.value: 0, + _ConfigParams.EVENTS_REFRESH_RATE.value: 0, + _ConfigParams.TELEMETRY_REFRESH_RATE.value: 0} + self._url_override = { + _ApiURLs.SDK_URL.value: False, + _ApiURLs.EVENTS_URL.value: False, + _ApiURLs.AUTH_URL.value: False, + _ApiURLs.STREAMING_URL.value: False, + _ApiURLs.TELEMETRY_URL.value: False} + self._impressions_queue_size = 0 + self._events_queue_size = 0 + self._impressions_mode = None + self._impression_listener = False + self._http_proxy = None + self._active_factory_count = 0 + self._redundant_factory_count = 0 + self._flag_sets = 0 + self._flag_sets_invalid = 0 + + @abc.abstractmethod + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + Record configurations. + """ + + @abc.abstractmethod + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """ + Record active and redundant factories counts + """ + + @abc.abstractmethod + def record_ready_time(self, ready_time): + """ + Record ready time. + """ + + @abc.abstractmethod + def record_bur_time_out(self): + """ + Record block until ready timeout count + """ + + @abc.abstractmethod + def record_not_ready_usage(self): + """ + record non-ready usage count + """ + + @abc.abstractmethod + def get_bur_time_outs(self): + """ + Get block until ready timeout. + """ + + @abc.abstractmethod + def get_non_ready_usage(self): + """ + Get non-ready usage. + """ + + @abc.abstractmethod + def get_stats(self): + """ + Get config stats. + """ + + def _get_operation_mode(self, op_mode): + """ + Get formatted operation mode + + :param op_mode: config operation mode + :type config: str + + :return: operation mode + :rtype: int + """ + if op_mode == OperationMode.STANDALONE.value: + return 0 + + elif op_mode == OperationMode.CONSUMER.value: + return 1 + + else: + return 2 + + def _get_storage_type(self, op_mode, st_type): + """ + Get storage type from operation mode + + :param op_mode: config operation mode + :type config: str + + :return: storage type + :rtype: str + """ + if op_mode == OperationMode.STANDALONE.value: + return StorageType.MEMORY.value + + elif st_type == StorageType.REDIS.value: + return StorageType.REDIS.value + + else: + return StorageType.PLUGGABLE.value + + def _get_refresh_rates(self, config): + """ + Get refresh rates within config dict + + :param config: config dict + :type config: dict + + :return: refresh rates + :rtype: RefreshRates object + """ + return { + _ConfigParams.SPLITS_REFRESH_RATE.value: config[_ConfigParams.SPLITS_REFRESH_RATE.value], + _ConfigParams.SEGMENTS_REFRESH_RATE.value: config[_ConfigParams.SEGMENTS_REFRESH_RATE.value], + _ConfigParams.IMPRESSIONS_REFRESH_RATE.value: config[_ConfigParams.IMPRESSIONS_REFRESH_RATE.value], + _ConfigParams.EVENTS_REFRESH_RATE.value: config[_ConfigParams.EVENTS_REFRESH_RATE.value], + _ConfigParams.TELEMETRY_REFRESH_RATE.value: config[_ConfigParams.TELEMETRY_REFRESH_RATE.value] + } + + def _get_url_overrides(self, config): + """ + Get URL override within the config dict. + + :param config: config dict + :type config: dict + + :return: URL overrides dict + :rtype: URLOverrides object + """ + return { + _ApiURLs.SDK_URL.value: True if _ApiURLs.SDK_URL.value in config else False, + _ApiURLs.EVENTS_URL.value: True if _ApiURLs.EVENTS_URL.value in config else False, + _ApiURLs.AUTH_URL.value: True if _ApiURLs.AUTH_URL.value in config else False, + _ApiURLs.STREAMING_URL.value: True if _ApiURLs.STREAMING_URL.value in config else False, + _ApiURLs.TELEMETRY_URL.value: True if _ApiURLs.TELEMETRY_URL.value in config else False + } + + def _get_impressions_mode(self, imp_mode): + """ + Get impressions mode from operation mode + + :param op_mode: config operation mode + :type config: str + + :return: impressions mode + :rtype: int + """ + if imp_mode == ImpressionsMode.DEBUG.value: + return 1 + + elif imp_mode == ImpressionsMode.OPTIMIZED.value: + return 0 + + else: + return 2 + + def _check_if_proxy_detected(self): + """ + Return boolean flag if network https proxy is detected + + :return: https network proxy flag + :rtype: boolean + """ + for x in os.environ: + if x.upper() == _ExtraConfig.HTTPS_PROXY_ENV.value: + return True + + return False + + +class TelemetryConfig(TelemetryConfigBase): + """ + Telemetry init config class + + """ + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + with self._lock: + self._reset_all() + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + Record configurations. + + :param config: config dict: { + 'operationMode': int, 'storageType': string, 'streamingEnabled': boolean, + 'refreshRate' : { + 'featuresRefreshRate': int, + 'segmentsRefreshRate': int, + 'impressionsRefreshRate': int, + 'eventsPushRate': int, + 'metricsRefreshRate': int + } + 'urlOverride' : { + 'sdk_url': boolean, 'events_url': boolean, 'auth_url': boolean, + 'streaming_url': boolean, 'telemetry_url': boolean, } + }, + 'impressionsQueueSize': int, 'eventsQueueSize': int, 'impressionsMode': string, + 'impressionsListener': boolean, 'activeFactoryCount': int, 'redundantFactoryCount': int + } + :type config: dict + """ + with self._lock: + self._operation_mode = self._get_operation_mode(config[_ConfigParams.OPERATION_MODE.value]) + self._storage_type = self._get_storage_type(config[_ConfigParams.OPERATION_MODE.value], config[_ConfigParams.STORAGE_TYPE.value]) + self._streaming_enabled = config[_ConfigParams.STREAMING_ENABLED.value] + self._refresh_rate = self._get_refresh_rates(config) + self._url_override = self._get_url_overrides(extra_config) + self._impressions_queue_size = config[_ConfigParams.IMPRESSIONS_QUEUE_SIZE.value] + self._events_queue_size = config[_ConfigParams.EVENTS_QUEUE_SIZE.value] + self._impressions_mode = self._get_impressions_mode(config[_ConfigParams.IMPRESSIONS_MODE.value]) + self._impression_listener = True if config[_ConfigParams.IMPRESSIONS_LISTENER.value] is not None else False + self._http_proxy = self._check_if_proxy_detected() + self._flag_sets = total_flag_sets + self._flag_sets_invalid = invalid_flag_sets + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """ + Record active and redundant factories counts + + :param active_factory_count: active factories count + :type active_factory_count: int + + :param redundant_factory_count: redundant factories count + :type redundant_factory_count: int + """ + with self._lock: + self._active_factory_count = active_factory_count + self._redundant_factory_count = redundant_factory_count + + def record_ready_time(self, ready_time): + """ + Record ready time. + + :param ready_time: SDK ready time + :type ready_time: int + """ + with self._lock: + self._time_until_ready = ready_time + + def record_bur_time_out(self): + """ + Record block until ready timeout count + + """ + with self._lock: + self._block_until_ready_timeout += 1 + + def record_not_ready_usage(self): + """ + record non-ready usage count + + """ + with self._lock: + self._not_ready += 1 + + def get_bur_time_outs(self): + """ + Get block until ready timeout. + + :return: block until ready timeouts count + :rtype: int + """ + with self._lock: + return self._block_until_ready_timeout + + def get_non_ready_usage(self): + """ + Get non-ready usage. + + :return: non-ready usage count + :rtype: int + """ + with self._lock: + return self._not_ready + + def get_stats(self): + """ + Get config stats. + + :return: dict of all config stats. + :rtype: dict + """ + with self._lock: + return { + 'bT': self._block_until_ready_timeout, + 'nR': self._not_ready, + 'tR': self._time_until_ready, + 'oM': self._operation_mode, + 'sT': self._storage_type, + 'sE': self._streaming_enabled, + 'rR': { + 'sp': self._refresh_rate[_ConfigParams.SPLITS_REFRESH_RATE.value], + 'se': self._refresh_rate[_ConfigParams.SEGMENTS_REFRESH_RATE.value], + 'im': self._refresh_rate[_ConfigParams.IMPRESSIONS_REFRESH_RATE.value], + 'ev': self._refresh_rate[_ConfigParams.EVENTS_REFRESH_RATE.value], + 'te': self._refresh_rate[_ConfigParams.TELEMETRY_REFRESH_RATE.value]}, + 'uO': { + 's': self._url_override[_ApiURLs.SDK_URL.value], + 'e': self._url_override[_ApiURLs.EVENTS_URL.value], + 'a': self._url_override[_ApiURLs.AUTH_URL.value], + 'st': self._url_override[_ApiURLs.STREAMING_URL.value], + 't': self._url_override[_ApiURLs.TELEMETRY_URL.value]}, + 'iQ': self._impressions_queue_size, + 'eQ': self._events_queue_size, + 'iM': self._impressions_mode, + 'iL': self._impression_listener, + 'hp': self._http_proxy, + 'aF': self._active_factory_count, + 'rF': self._redundant_factory_count, + 'fsT': self._flag_sets, + 'fsI': self._flag_sets_invalid + } + + +class TelemetryConfigAsync(TelemetryConfigBase): + """ + Telemetry init config async class + + """ + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + async with self._lock: + self._reset_all() + return self + + async def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + Record configurations. + + :param config: config dict: { + 'operationMode': int, 'storageType': string, 'streamingEnabled': boolean, + 'refreshRate' : { + 'featuresRefreshRate': int, + 'segmentsRefreshRate': int, + 'impressionsRefreshRate': int, + 'eventsPushRate': int, + 'metricsRefreshRate': int + } + 'urlOverride' : { + 'sdk_url': boolean, 'events_url': boolean, 'auth_url': boolean, + 'streaming_url': boolean, 'telemetry_url': boolean, } + }, + 'impressionsQueueSize': int, 'eventsQueueSize': int, 'impressionsMode': string, + 'impressionsListener': boolean, 'activeFactoryCount': int, 'redundantFactoryCount': int + } + :type config: dict + """ + async with self._lock: + self._operation_mode = self._get_operation_mode(config[_ConfigParams.OPERATION_MODE.value]) + self._storage_type = self._get_storage_type(config[_ConfigParams.OPERATION_MODE.value], config[_ConfigParams.STORAGE_TYPE.value]) + self._streaming_enabled = config[_ConfigParams.STREAMING_ENABLED.value] + self._refresh_rate = self._get_refresh_rates(config) + self._url_override = self._get_url_overrides(extra_config) + self._impressions_queue_size = config[_ConfigParams.IMPRESSIONS_QUEUE_SIZE.value] + self._events_queue_size = config[_ConfigParams.EVENTS_QUEUE_SIZE.value] + self._impressions_mode = self._get_impressions_mode(config[_ConfigParams.IMPRESSIONS_MODE.value]) + self._impression_listener = True if config[_ConfigParams.IMPRESSIONS_LISTENER.value] is not None else False + self._http_proxy = self._check_if_proxy_detected() + self._flag_sets = total_flag_sets + self._flag_sets_invalid = invalid_flag_sets + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """ + Record active and redundant factories counts + + :param active_factory_count: active factories count + :type active_factory_count: int + + :param redundant_factory_count: redundant factories count + :type redundant_factory_count: int + """ + async with self._lock: + self._active_factory_count = active_factory_count + self._redundant_factory_count = redundant_factory_count + + async def record_ready_time(self, ready_time): + """ + Record ready time. + + :param ready_time: SDK ready time + :type ready_time: int + """ + async with self._lock: + self._time_until_ready = ready_time + + async def record_bur_time_out(self): + """ + Record block until ready timeout count + + """ + async with self._lock: + self._block_until_ready_timeout += 1 + + async def record_not_ready_usage(self): + """ + record non-ready usage count + + """ + async with self._lock: + self._not_ready += 1 + + async def get_bur_time_outs(self): + """ + Get block until ready timeout. + + :return: block until ready timeouts count + :rtype: int + """ + async with self._lock: + return self._block_until_ready_timeout + + async def get_non_ready_usage(self): + """ + Get non-ready usage. + + :return: non-ready usage count + :rtype: int + """ + async with self._lock: + return self._not_ready + + async def get_stats(self): + """ + Get config stats. + + :return: dict of all config stats. + :rtype: dict + """ + async with self._lock: + return { + 'bT': self._block_until_ready_timeout, + 'nR': self._not_ready, + 'tR': self._time_until_ready, + 'oM': self._operation_mode, + 'sT': self._storage_type, + 'sE': self._streaming_enabled, + 'rR': { + 'sp': self._refresh_rate[_ConfigParams.SPLITS_REFRESH_RATE.value], + 'se': self._refresh_rate[_ConfigParams.SEGMENTS_REFRESH_RATE.value], + 'im': self._refresh_rate[_ConfigParams.IMPRESSIONS_REFRESH_RATE.value], + 'ev': self._refresh_rate[_ConfigParams.EVENTS_REFRESH_RATE.value], + 'te': self._refresh_rate[_ConfigParams.TELEMETRY_REFRESH_RATE.value]}, + 'uO': { + 's': self._url_override[_ApiURLs.SDK_URL.value], + 'e': self._url_override[_ApiURLs.EVENTS_URL.value], + 'a': self._url_override[_ApiURLs.AUTH_URL.value], + 'st': self._url_override[_ApiURLs.STREAMING_URL.value], + 't': self._url_override[_ApiURLs.TELEMETRY_URL.value]}, + 'iQ': self._impressions_queue_size, + 'eQ': self._events_queue_size, + 'iM': self._impressions_mode, + 'iL': self._impression_listener, + 'hp': self._http_proxy, + 'aF': self._active_factory_count, + 'rF': self._redundant_factory_count, + 'fsT': self._flag_sets, + 'fsI': self._flag_sets_invalid + } \ No newline at end of file diff --git a/splitio/models/token.py b/splitio/models/token.py index 33c4f48c..f2b0cf9c 100644 --- a/splitio/models/token.py +++ b/splitio/models/token.py @@ -58,25 +58,6 @@ def iat(self): return self._iat -def decode_token(raw_token): - """Decode token""" - if not 'pushEnabled' in raw_token or not 'token' in raw_token: - return None, None, None - - token = raw_token['token'] - push_enabled = raw_token['pushEnabled'] - if not push_enabled or len(token.strip()) == 0: - return None, None, None - - token_parts = token.split('.') - if len(token_parts) < 2: - return None, None, None - - to_decode = token_parts[1] - decoded_payload = base64.b64decode(to_decode + '='*(-len(to_decode) % 4)) - return push_enabled, token, json.loads(decoded_payload) - - def from_raw(raw_token): """ Parse a new token from a raw token response. @@ -87,5 +68,16 @@ def from_raw(raw_token): :return: New token model object :rtype: splitio.models.token.Token """ - push_enabled, token, decoded_token = decode_token(raw_token) - return None if push_enabled is None else Token(push_enabled, token, json.loads(decoded_token['x-ably-capability']), decoded_token['exp'], decoded_token['iat']) + if not 'pushEnabled' in raw_token or not 'token' in raw_token: + return Token(False, None, None, None, None) + + token = raw_token['token'] + push_enabled = raw_token['pushEnabled'] + token_parts = token.strip().split('.') + + if not push_enabled or len(token_parts) < 2: + return Token(False, None, None, None, None) + + to_decode = token_parts[1] + decoded_token = json.loads(base64.b64decode(to_decode + '='*(-len(to_decode) % 4))) + return Token(push_enabled, token, json.loads(decoded_token['x-ably-capability']), decoded_token['exp'], decoded_token['iat']) \ No newline at end of file diff --git a/splitio/optional/__init__.py b/splitio/optional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/splitio/optional/loaders.py b/splitio/optional/loaders.py new file mode 100644 index 00000000..b5f11621 --- /dev/null +++ b/splitio/optional/loaders.py @@ -0,0 +1,30 @@ +import sys +try: + import asyncio + import aiohttp + import aiofiles +except ImportError: + def missing_asyncio_dependencies(*_, **__): + """Fail if missing dependencies are used.""" + raise NotImplementedError( + 'Missing aiohttp dependency. ' + 'Please use `pip install splitio_client[asyncio]` to install the sdk with asyncio support' + ) + aiohttp = missing_asyncio_dependencies + asyncio = missing_asyncio_dependencies + aiofiles = missing_asyncio_dependencies + +try: + from requests_kerberos import HTTPKerberosAuth, OPTIONAL +except ImportError: + def missing_auth_dependencies(*_, **__): + """Fail if missing dependencies are used.""" + raise NotImplementedError( + 'Missing kerberos auth dependency. ' + 'Please use `pip install splitio_client[kerberos]` to install the sdk with kerberos auth support' + ) + HTTPKerberosAuth = missing_auth_dependencies + OPTIONAL = missing_auth_dependencies + +async def _anext(it): + return await it.__anext__() diff --git a/splitio/push/__init__.py b/splitio/push/__init__.py index e69de29b..a7a9b624 100644 --- a/splitio/push/__init__.py +++ b/splitio/push/__init__.py @@ -0,0 +1,13 @@ +class AuthException(Exception): + """Exception to raise when an API call fails.""" + + def __init__(self, custom_message, status_code=None): + """Constructor.""" + Exception.__init__(self, custom_message) + +class SplitStorageException(Exception): + """Exception to raise when an API call fails.""" + + def __init__(self, custom_message, status_code=None): + """Constructor.""" + Exception.__init__(self, custom_message) diff --git a/splitio/push/manager.py b/splitio/push/manager.py index fb75464b..2046d610 100644 --- a/splitio/push/manager.py +++ b/splitio/push/manager.py @@ -1,26 +1,51 @@ """Push subsystem manager class and helpers.""" - import logging from threading import Timer +import abc +import sys +from splitio.optional.loaders import asyncio from splitio.api import APIException -from splitio.push.splitsse import SplitSSEClient +from splitio.util.time import get_current_epoch_time_ms +from splitio.push import AuthException +from splitio.push.splitsse import SplitSSEClient, SplitSSEClientAsync +from splitio.push.sse import SSE_EVENT_ERROR from splitio.push.parser import parse_incoming_event, EventParsingException, EventType, \ MessageType -from splitio.push.processor import MessageProcessor -from splitio.push.status_tracker import PushStatusTracker, Status +from splitio.push.processor import MessageProcessor, MessageProcessorAsync +from splitio.push.status_tracker import PushStatusTracker, Status, PushStatusTrackerAsync +from splitio.models.telemetry import StreamingEventTypes +if sys.version_info.major == 3 and sys.version_info.minor < 10: + from splitio.optional.loaders import _anext as anext _TOKEN_REFRESH_GRACE_PERIOD = 10 * 60 # 10 minutes - _LOGGER = logging.getLogger(__name__) +class PushManagerBase(object, metaclass=abc.ABCMeta): + """Worker template.""" + + @abc.abstractmethod + def update_workers_status(self, enabled): + """Enable/Disable push update workers.""" + + @abc.abstractmethod + def start(self): + """Start a new connection if not already running.""" + + @abc.abstractmethod + def stop(self, blocking=False): + """Stop the current ongoing connection.""" + + def _get_time_period(self, token): + return (token.exp - token.iat) - _TOKEN_REFRESH_GRACE_PERIOD + -class PushManager(object): # pylint:disable=too-many-instance-attributes +class PushManager(PushManagerBase): # pylint:disable=too-many-instance-attributes """Push notifications susbsytem manager.""" - def __init__(self, auth_api, synchronizer, feedback_loop, sdk_metadata, sse_url=None, client_key=None): + def __init__(self, auth_api, synchronizer, feedback_loop, sdk_metadata, telemetry_runtime_producer, sse_url=None, client_key=None): """ Class constructor. @@ -36,6 +61,9 @@ def __init__(self, auth_api, synchronizer, feedback_loop, sdk_metadata, sse_url= :param sdk_metadata: SDK version & machine name & IP. :type sdk_metadata: splitio.client.util.SdkMetadata + :param telemetry_runtime_producer: Telemetry object to record runtime events + :type sdk_metadata: splitio.engine.telemetry.TelemetryRunTimeProducer + :param sse_url: streaming base url. :type sse_url: str @@ -44,8 +72,8 @@ def __init__(self, auth_api, synchronizer, feedback_loop, sdk_metadata, sse_url= """ self._auth_api = auth_api self._feedback_loop = feedback_loop - self._processor = MessageProcessor(synchronizer) - self._status_tracker = PushStatusTracker() + self._processor = MessageProcessor(synchronizer, telemetry_runtime_producer) + self._status_tracker = PushStatusTracker(telemetry_runtime_producer) self._event_handlers = { EventType.MESSAGE: self._handle_message, EventType.ERROR: self._handle_error @@ -62,6 +90,8 @@ def __init__(self, auth_api, synchronizer, feedback_loop, sdk_metadata, sse_url= self._handle_connection_end, client_key, **kwargs) self._running = False self._next_refresh = Timer(0, lambda: 0) + self._telemetry_runtime_producer = telemetry_runtime_producer + def update_workers_status(self, enabled): """ @@ -141,16 +171,17 @@ def _trigger_connection_flow(self): self._feedback_loop.put(Status.PUSH_RETRYABLE_ERROR) return - if not token.push_enabled: + if token is None or not token.push_enabled: self._feedback_loop.put(Status.PUSH_NONRETRYABLE_ERROR) return - + self._telemetry_runtime_producer.record_token_refreshes() _LOGGER.debug("auth token fetched. connecting to streaming.") self._status_tracker.reset() if self._sse_client.start(token): _LOGGER.debug("connected to streaming, scheduling next refresh") self._setup_next_token_refresh(token) self._running = True + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.CONNECTION_ESTABLISHED, 0, get_current_epoch_time_ms())) def _setup_next_token_refresh(self, token): """ @@ -161,10 +192,10 @@ def _setup_next_token_refresh(self, token): """ if self._next_refresh is not None: self._next_refresh.cancel() - self._next_refresh = Timer((token.exp - token.iat) - _TOKEN_REFRESH_GRACE_PERIOD, - self._token_refresh) + self._next_refresh = Timer(self._get_time_period(token), self._token_refresh) self._next_refresh.setName('TokenRefresh') self._next_refresh.start() + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.TOKEN_REFRESH, 1000 * token.exp, get_current_epoch_time_ms())) def _handle_message(self, event): """ @@ -242,3 +273,267 @@ def _handle_connection_end(self): feedback = self._status_tracker.handle_disconnect() if feedback is not None: self._feedback_loop.put(feedback) + + +class PushManagerAsync(PushManagerBase): # pylint:disable=too-many-instance-attributes + """Push notifications susbsytem manager.""" + + def __init__(self, auth_api, synchronizer, feedback_loop, sdk_metadata, telemetry_runtime_producer, sse_url=None, client_key=None): + """ + Class constructor. + + :param auth_api: sdk-auth-service api client + :type auth_api: splitio.api.auth.AuthAPI + + :param synchronizer: split data synchronizer facade + :type synchronizer: splitio.sync.synchronizer.Synchronizer + + :param feedback_loop: queue where push status updates are published. + :type feedback_loop: queue.Queue + + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + + :param telemetry_runtime_producer: Telemetry object to record runtime events + :type sdk_metadata: splitio.engine.telemetry.TelemetryRunTimeProducer + + :param sse_url: streaming base url. + :type sse_url: str + + :param client_key: client key. + :type client_key: str + """ + self._auth_api = auth_api + self._feedback_loop = feedback_loop + self._processor = MessageProcessorAsync(synchronizer, telemetry_runtime_producer) + self._status_tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + self._event_handlers = { + EventType.MESSAGE: self._handle_message, + EventType.ERROR: self._handle_error + } + + self._message_handlers = { + MessageType.UPDATE: self._handle_update, + MessageType.CONTROL: self._handle_control, + MessageType.OCCUPANCY: self._handle_occupancy + } + + kwargs = {} if sse_url is None else {'base_url': sse_url} + self._sse_client = SplitSSEClientAsync(sdk_metadata, client_key, **kwargs) + self._running = False + self._telemetry_runtime_producer = telemetry_runtime_producer + self._token_task = None + + async def update_workers_status(self, enabled): + """ + Enable/Disable push update workers. + + :param enabled: if True, enable workers. If False, disable them. + :type enabled: bool + """ + await self._processor.update_workers_status(enabled) + + def start(self): + """Start a new connection if not already running.""" + if self._running: + _LOGGER.warning('Push manager already has a connection running. Ignoring') + return + + self._running_task = asyncio.get_running_loop().create_task(self._trigger_connection_flow()) + + async def stop(self, blocking=False): + """ + Stop the current ongoing connection. + + :param blocking: whether to wait for the connection to be successfully closed or not + :type blocking: bool + """ + if not self._running: + _LOGGER.warning('Push manager does not have an open SSE connection. Ignoring') + return + + if self._token_task: + self._token_task.cancel() + self._token_task = None + + if blocking: + await self._stop_current_conn() + else: + asyncio.get_running_loop().create_task(self._stop_current_conn()) + + async def close_sse_http_client(self): + await self._sse_client.close_sse_http_client() + + async def _event_handler(self, event): + """ + Process an incoming event. + + :param event: Incoming event + :type event: splitio.push.sse.SSEEvent + """ + parsed = None + try: + parsed = parse_incoming_event(event) + handle = self._event_handlers[parsed.event_type] + except Exception: + _LOGGER.error('Parsing exception or no handler for message of type %s', parsed.event_type if parsed else 'unknown') + _LOGGER.debug(str(event), exc_info=True) + return + + try: + await handle(parsed) + except Exception: # pylint:disable=broad-except + event_type = "unknown" if parsed is None else parsed.event_type + _LOGGER.error('something went wrong when processing message of type %s', event_type) + _LOGGER.debug(str(parsed), exc_info=True) + + async def _token_refresh(self, current_token): + """Refresh auth token. + + :param current_token: token (parsed) JWT + :type current_token: splitio.models.token.Token + """ + _LOGGER.debug("Next token refresh in " + str(self._get_time_period(current_token)) + " seconds") + await asyncio.sleep(self._get_time_period(current_token)) + await self._stop_current_conn() + self._running_task = asyncio.get_running_loop().create_task(self._trigger_connection_flow()) + + async def _get_auth_token(self): + """Get new auth token""" + try: + token = await self._auth_api.authenticate() + except APIException as e: + _LOGGER.error('error performing sse auth request.') + _LOGGER.debug('stack trace: ', exc_info=True) + await self._feedback_loop.put(Status.PUSH_RETRYABLE_ERROR) + raise AuthException(e) + + if token is not None and not token.push_enabled: + await self._feedback_loop.put(Status.PUSH_NONRETRYABLE_ERROR) + raise AuthException("Push is not enabled") + + await self._telemetry_runtime_producer.record_token_refreshes() + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.TOKEN_REFRESH, 1000 * token.exp, get_current_epoch_time_ms())) + _LOGGER.debug("auth token fetched. connecting to streaming.") + return token + + async def _trigger_connection_flow(self): + """Authenticate and start a connection.""" + self._status_tracker.reset() + + try: + token = await self._get_auth_token() + events_source = self._sse_client.start(token) + self._running = True + + first_event = await anext(events_source) + if first_event.data is not None: + await self._event_handler(first_event) + + _LOGGER.debug("connected to streaming, scheduling next refresh") + self._token_task = asyncio.get_running_loop().create_task(self._token_refresh(token)) + await self._handle_connection_ready() + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.CONNECTION_ESTABLISHED, 0, get_current_epoch_time_ms())) + + async for event in events_source: + await self._event_handler(event) + await self._handle_connection_end() # TODO(mredolatti): this is not tested + except AuthException as e: + _LOGGER.error("error getting auth token: " + str(e)) + _LOGGER.debug("trace: ", exc_info=True) + except StopAsyncIteration: # will enter here if there was an error + await self._feedback_loop.put(Status.PUSH_RETRYABLE_ERROR) + finally: + if self._token_task is not None: + self._token_task.cancel() + self._token_task = None + self._running = False + await self._processor.update_workers_status(False) + + async def _handle_message(self, event): + """ + Handle incoming update message. + + :param event: Incoming Update message + :type event: splitio.push.sse.parser.Update + """ + try: + handle = self._message_handlers[event.message_type] + except KeyError: + _LOGGER.error('no handler for message of type %s', event.message_type) + _LOGGER.debug(str(event), exc_info=True) + return + + await handle(event) + + async def _handle_update(self, event): + """ + Handle incoming update message. + + :param event: Incoming Update message + :type event: splitio.push.sse.parser.Update + """ + _LOGGER.debug('handling update event: %s', str(event)) + await self._processor.handle(event) + + async def _handle_control(self, event): + """ + Handle incoming control message. + + :param event: Incoming control message. + :type event: splitio.push.sse.parser.ControlMessage + """ + _LOGGER.debug('handling control event: %s', str(event)) + feedback = await self._status_tracker.handle_control_message(event) + if feedback is not None: + await self._feedback_loop.put(feedback) + + async def _handle_occupancy(self, event): + """ + Handle incoming notification message. + + :param event: Incoming occupancy message. + :type event: splitio.push.sse.parser.Occupancy + """ + _LOGGER.debug('handling occupancy event: %s', str(event)) + feedback = await self._status_tracker.handle_occupancy(event) + if feedback is not None: + await self._feedback_loop.put(feedback) + + async def _handle_error(self, event): + """ + Handle incoming error message. + + :param event: Incoming ably error + :type event: splitio.push.sse.parser.AblyError + """ + _LOGGER.debug('handling ably error event: %s', str(event)) + feedback = await self._status_tracker.handle_ably_error(event) + if feedback is not None: + await self._feedback_loop.put(feedback) + + async def _handle_connection_ready(self): + """Handle a successful connection to SSE.""" + await self._feedback_loop.put(Status.PUSH_SUBSYSTEM_UP) + _LOGGER.info('sse initial event received. enabling') + + async def _handle_connection_end(self): + """ + Handle a connection ending. + + If the connection shutdown was not requested, trigger a restart. + """ + feedback = await self._status_tracker.handle_disconnect() + if feedback is not None: + await self._feedback_loop.put(feedback) + + async def _stop_current_conn(self): + """Abort current streaming connection and stop it's associated workers.""" + _LOGGER.debug("Aborting SplitSSE tasks.") + await self._processor.update_workers_status(False) + self._status_tracker.notify_sse_shutdown_expected() + await self._sse_client.stop() + self._running_task.cancel() + await self._running_task + self._running_task = None + _LOGGER.debug("SplitSSE tasks are stopped") diff --git a/splitio/push/parser.py b/splitio/push/parser.py index 7d44096b..79b410e3 100644 --- a/splitio/push/parser.py +++ b/splitio/push/parser.py @@ -4,10 +4,9 @@ from enum import Enum from splitio.util.decorators import abstract_property -from splitio.util import utctime_ms +from splitio.util.time import utctime_ms from splitio.push.sse import SSE_EVENT_ERROR, SSE_EVENT_MESSAGE - class EventType(Enum): """Event type enumeration.""" @@ -29,6 +28,7 @@ class UpdateType(Enum): SPLIT_UPDATE = 'SPLIT_UPDATE' SPLIT_KILL = 'SPLIT_KILL' SEGMENT_UPDATE = 'SEGMENT_UPDATE' + RB_SEGMENT_UPDATE = 'RB_SEGMENT_UPDATE' class ControlType(Enum): @@ -277,7 +277,7 @@ def __str__(self): class BaseUpdate(BaseMessage, metaclass=abc.ABCMeta): - """Split data update notification.""" + """Feature flag data update notification.""" def __init__(self, channel, timestamp, change_number): """ @@ -324,11 +324,14 @@ def change_number(self): class SplitChangeUpdate(BaseUpdate): - """Split Change notification.""" + """Feature flag Change notification.""" - def __init__(self, channel, timestamp, change_number): + def __init__(self, channel, timestamp, change_number, previous_change_number, feature_flag_definition, compression): """Class constructor.""" BaseUpdate.__init__(self, channel, timestamp, change_number) + self._previous_change_number = previous_change_number + self._object_definition = feature_flag_definition + self._compression = compression @property def update_type(self): # pylint:disable=no-self-use @@ -340,18 +343,45 @@ def update_type(self): # pylint:disable=no-self-use """ return UpdateType.SPLIT_UPDATE + @property + def previous_change_number(self): # pylint:disable=no-self-use + """ + Return previous change number + :returns: The previous change number + :rtype: int + """ + return self._previous_change_number + + @property + def object_definition(self): # pylint:disable=no-self-use + """ + Return feature flag definition + :returns: The new feature flag definition + :rtype: str + """ + return self._object_definition + + @property + def compression(self): # pylint:disable=no-self-use + """ + Return previous compression type + :returns: The compression type + :rtype: int + """ + return self._compression + def __str__(self): """Return string representation.""" return "SplitChange - changeNumber=%d" % (self.change_number) class SplitKillUpdate(BaseUpdate): - """Split Kill notification.""" + """Feature flag Kill notification.""" - def __init__(self, channel, timestamp, change_number, split_name, default_treatment): # pylint:disable=too-many-arguments + def __init__(self, channel, timestamp, change_number, feature_flag_name, default_treatment): # pylint:disable=too-many-arguments """Class constructor.""" BaseUpdate.__init__(self, channel, timestamp, change_number) - self._split_name = split_name + self._feature_flag_name = feature_flag_name self._default_treatment = default_treatment @property @@ -365,14 +395,14 @@ def update_type(self): # pylint:disable=no-self-use return UpdateType.SPLIT_KILL @property - def split_name(self): + def feature_flag_name(self): """ - Return the name of the killed split. + Return the name of the killed feature flag. - :returns: name of the killed split + :returns: name of the killed feature flag :rtype: str """ - return self._split_name + return self._feature_flag_name @property def default_treatment(self): @@ -387,7 +417,7 @@ def default_treatment(self): def __str__(self): """Return string representation.""" return "SplitKill - changeNumber=%d, name=%s, defaultTreatment=%s" % \ - (self.change_number, self.split_name, self.default_treatment) + (self.change_number, self.feature_flag_name, self.default_treatment) class SegmentChangeUpdate(BaseUpdate): @@ -422,6 +452,56 @@ def __str__(self): """Return string representation.""" return "SegmentChange - changeNumber=%d, name=%s" % (self.change_number, self.segment_name) +class RBSChangeUpdate(BaseUpdate): + """rbs Change notification.""" + + def __init__(self, channel, timestamp, change_number, previous_change_number, rbs_definition, compression): + """Class constructor.""" + BaseUpdate.__init__(self, channel, timestamp, change_number) + self._previous_change_number = previous_change_number + self._object_definition = rbs_definition + self._compression = compression + + @property + def update_type(self): # pylint:disable=no-self-use + """ + Return the message type. + + :returns: The type of this parsed Update. + :rtype: UpdateType + """ + return UpdateType.RB_SEGMENT_UPDATE + + @property + def previous_change_number(self): # pylint:disable=no-self-use + """ + Return previous change number + :returns: The previous change number + :rtype: int + """ + return self._previous_change_number + + @property + def object_definition(self): # pylint:disable=no-self-use + """ + Return rbs definition + :returns: The new rbs definition + :rtype: str + """ + return self._object_definition + + @property + def compression(self): # pylint:disable=no-self-use + """ + Return previous compression type + :returns: The compression type + :rtype: int + """ + return self._compression + + def __str__(self): + """Return string representation.""" + return "RBSChange - changeNumber=%d" % (self.change_number) class ControlMessage(BaseMessage): """Control notification.""" @@ -471,13 +551,19 @@ def _parse_update(channel, timestamp, data): """ update_type = UpdateType(data['type']) change_number = data['changeNumber'] - if update_type == UpdateType.SPLIT_UPDATE: - return SplitChangeUpdate(channel, timestamp, change_number) - elif update_type == UpdateType.SPLIT_KILL: + if update_type == UpdateType.SPLIT_UPDATE and change_number is not None: + return SplitChangeUpdate(channel, timestamp, change_number, data.get('pcn'), data.get('d'), data.get('c')) + + if update_type == UpdateType.RB_SEGMENT_UPDATE and change_number is not None: + return RBSChangeUpdate(channel, timestamp, change_number, data.get('pcn'), data.get('d'), data.get('c')) + + elif update_type == UpdateType.SPLIT_KILL and change_number is not None: return SplitKillUpdate(channel, timestamp, change_number, data['splitName'], data['defaultTreatment']) + elif update_type == UpdateType.SEGMENT_UPDATE: return SegmentChangeUpdate(channel, timestamp, change_number, data['segmentName']) + raise EventParsingException('unrecognized event type %s' % update_type) @@ -493,15 +579,19 @@ def _parse_message(data): """ if not all(k in data for k in ['data', 'channel']): return None + channel = data['channel'] timestamp = data['timestamp'] parsed_data = json.loads(data['data']) if data.get('name') == TAG_OCCUPANCY: return OccupancyMessage(channel, timestamp, parsed_data['metrics']['publishers']) + elif parsed_data['type'] == 'CONTROL': return ControlMessage(channel, timestamp, parsed_data['controlType']) + elif parsed_data['type'] in UpdateType.__members__: return _parse_update(channel, timestamp, parsed_data) + raise EventParsingException('unrecognized message type %s' % parsed_data['type']) diff --git a/splitio/push/processor.py b/splitio/push/processor.py index 39329b6b..41d796c7 100644 --- a/splitio/push/processor.py +++ b/splitio/push/processor.py @@ -1,52 +1,68 @@ """Message processor & Notification manager keeper implementations.""" from queue import Queue +import abc from splitio.push.parser import UpdateType -from splitio.push.splitworker import SplitWorker -from splitio.push.segmentworker import SegmentWorker +from splitio.push.workers import SplitWorker, SplitWorkerAsync, SegmentWorker, SegmentWorkerAsync +from splitio.optional.loaders import asyncio +class MessageProcessorBase(object, metaclass=abc.ABCMeta): + """Message processor template.""" -class MessageProcessor(object): + @abc.abstractmethod + def update_workers_status(self, enabled): + """Enable/Disable push update workers.""" + + @abc.abstractmethod + def handle(self, event): + """Handle incoming update event.""" + + @abc.abstractmethod + def shutdown(self): + """Stop splits & segments workers.""" + +class MessageProcessor(MessageProcessorBase): """Message processor class.""" - def __init__(self, synchronizer): + def __init__(self, synchronizer, telemetry_runtime_producer): """ Class constructor. :param synchronizer: synchronizer component :type synchronizer: splitio.sync.synchronizer.Synchronizer """ - self._split_queue = Queue() + self._feature_flag_queue = Queue() self._segments_queue = Queue() self._synchronizer = synchronizer - self._split_worker = SplitWorker(synchronizer.synchronize_splits, self._split_queue) + self._feature_flag_worker = SplitWorker(synchronizer.synchronize_splits, synchronizer.synchronize_segment, self._feature_flag_queue, synchronizer.split_sync.feature_flag_storage, synchronizer.segment_storage, telemetry_runtime_producer, synchronizer.split_sync.rule_based_segment_storage) self._segments_worker = SegmentWorker(synchronizer.synchronize_segment, self._segments_queue) self._handlers = { - UpdateType.SPLIT_UPDATE: self._handle_split_update, - UpdateType.SPLIT_KILL: self._handle_split_kill, - UpdateType.SEGMENT_UPDATE: self._handle_segment_change + UpdateType.SPLIT_UPDATE: self._handle_feature_flag_update, + UpdateType.SPLIT_KILL: self._handle_feature_flag_kill, + UpdateType.SEGMENT_UPDATE: self._handle_segment_change, + UpdateType.RB_SEGMENT_UPDATE: self._handle_feature_flag_update } - def _handle_split_update(self, event): + def _handle_feature_flag_update(self, event): """ - Handle incoming split update notification. + Handle incoming feature_flag update notification. - :param event: Incoming split change event + :param event: Incoming feature_flag change event :type event: splitio.push.parser.SplitChangeUpdate """ - self._split_queue.put(event) + self._feature_flag_queue.put(event) - def _handle_split_kill(self, event): + def _handle_feature_flag_kill(self, event): """ - Handle incoming split kill notification. + Handle incoming feature flag kill notification. - :param event: Incoming split kill event + :param event: Incoming feature flag kill event :type event: splitio.push.parser.SplitKillUpdate """ - self._synchronizer.kill_split(event.split_name, event.default_treatment, + self._synchronizer.kill_split(event.feature_flag_name, event.default_treatment, event.change_number) - self._split_queue.put(event) + self._feature_flag_queue.put(event) def _handle_segment_change(self, event): """ @@ -65,10 +81,10 @@ def update_workers_status(self, enabled): :type enabled: bool """ if enabled: - self._split_worker.start() + self._feature_flag_worker.start() self._segments_worker.start() else: - self._split_worker.stop() + self._feature_flag_worker.stop() self._segments_worker.stop() def handle(self, event): @@ -86,6 +102,91 @@ def handle(self, event): handle(event) def shutdown(self): - """Stop splits & segments workers.""" - self._split_worker.stop() + """Stop feature flags & segments workers.""" + self._feature_flag_worker.stop() self._segments_worker.stop() + + +class MessageProcessorAsync(MessageProcessorBase): + """Message processor class.""" + + def __init__(self, synchronizer, telemetry_runtime_producer): + """ + Class constructor. + + :param synchronizer: synchronizer component + :type synchronizer: splitio.sync.synchronizer.Synchronizer + """ + self._feature_flag_queue = asyncio.Queue() + self._segments_queue = asyncio.Queue() + self._synchronizer = synchronizer + self._feature_flag_worker = SplitWorkerAsync(synchronizer.synchronize_splits, synchronizer.synchronize_segment, self._feature_flag_queue, synchronizer.split_sync.feature_flag_storage, synchronizer.segment_storage, telemetry_runtime_producer, synchronizer.split_sync.rule_based_segment_storage) + self._segments_worker = SegmentWorkerAsync(synchronizer.synchronize_segment, self._segments_queue) + self._handlers = { + UpdateType.SPLIT_UPDATE: self._handle_feature_flag_update, + UpdateType.SPLIT_KILL: self._handle_feature_flag_kill, + UpdateType.SEGMENT_UPDATE: self._handle_segment_change, + UpdateType.RB_SEGMENT_UPDATE: self._handle_feature_flag_update + } + + async def _handle_feature_flag_update(self, event): + """ + Handle incoming feature_flag update notification. + + :param event: Incoming feature_flag change event + :type event: splitio.push.parser.SplitChangeUpdate + """ + await self._feature_flag_queue.put(event) + + async def _handle_feature_flag_kill(self, event): + """ + Handle incoming feature_flag kill notification. + + :param event: Incoming feature_flag kill event + :type event: splitio.push.parser.SplitKillUpdate + """ + await self._synchronizer.kill_split(event.feature_flag_name, event.default_treatment, + event.change_number) + await self._feature_flag_queue.put(event) + + async def _handle_segment_change(self, event): + """ + Handle incoming segment update notification. + + :param event: Incoming segment change event + :type event: splitio.push.parser.Update + """ + await self._segments_queue.put(event) + + async def update_workers_status(self, enabled): + """ + Enable/Disable push update workers. + + :param enabled: if True, enable workers. If False, disable them. + :type enabled: bool + """ + if enabled: + self._feature_flag_worker.start() + self._segments_worker.start() + else: + await self._feature_flag_worker.stop() + await self._segments_worker.stop() + + async def handle(self, event): + """ + Handle incoming update event. + + :param event: incoming data update event. + :type event: splitio.push.BaseUpdate + """ + try: + handle = self._handlers[event.update_type] + except KeyError as exc: + raise Exception('no handler for notification type: %s' % event.update_type) from exc + + await handle(event) + + async def shutdown(self): + """Stop splits & segments workers.""" + await self._feature_flag_worker.stop() + await self._segments_worker.stop() diff --git a/splitio/push/segmentworker.py b/splitio/push/segmentworker.py deleted file mode 100644 index 3861c602..00000000 --- a/splitio/push/segmentworker.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Segment changes processing worker.""" -import logging -import threading - - -_LOGGER = logging.getLogger(__name__) - - -class SegmentWorker(object): - """Segment Worker for processing updates.""" - - _centinel = object() - - def __init__(self, synchronize_segment, segment_queue): - """ - Class constructor. - - :param synchronize_segment: handler to perform segment synchronization on incoming event - :type synchronize_segment: function - - :param segment_queue: queue with segment updates notifications - :type segment_queue: queue - """ - self._segment_queue = segment_queue - self._handler = synchronize_segment - self._running = False - self._worker = None - - def is_running(self): - """Return whether the working is running.""" - return self._running - - def _run(self): - """Run worker handler.""" - while self.is_running(): - event = self._segment_queue.get() - if not self.is_running(): - break - if event == self._centinel: - continue - _LOGGER.debug('Processing segment_update: %s, change_number: %d', - event.segment_name, event.change_number) - try: - self._handler(event.segment_name, event.change_number) - except Exception: - _LOGGER.error('Exception raised in segment synchronization') - _LOGGER.debug('Exception information: ', exc_info=True) - - def start(self): - """Start worker.""" - if self.is_running(): - _LOGGER.debug('Worker is already running') - return - self._running = True - - _LOGGER.debug('Starting Segment Worker') - self._worker = threading.Thread(target=self._run, name='PushSegmentWorker') - self._worker.setDaemon(True) - self._worker.start() - - def stop(self): - """Stop worker.""" - _LOGGER.debug('Stopping Segment Worker') - if not self.is_running(): - _LOGGER.debug('Worker is not running. Ignoring.') - return - self._running = False - self._segment_queue.put(self._centinel) diff --git a/splitio/push/splitsse.py b/splitio/push/splitsse.py index f16a317f..788648d4 100644 --- a/splitio/push/splitsse.py +++ b/splitio/push/splitsse.py @@ -2,16 +2,21 @@ import logging import threading from enum import Enum -from splitio.push.sse import SSEClient, SSE_EVENT_ERROR +import abc +import sys + +from splitio.push.sse import SSEClient, SSEClientAsync, SSE_EVENT_ERROR from splitio.util.threadutil import EventGroup -from splitio.api.commons import headers_from_metadata +from splitio.api import headers_from_metadata +from splitio.optional.loaders import asyncio +if sys.version_info.major == 3 and sys.version_info.minor < 10: + from splitio.optional.loaders import _anext as anext _LOGGER = logging.getLogger(__name__) - -class SplitSSEClient(object): # pylint: disable=too-many-instance-attributes - """Split streaming endpoint SSE client.""" +class SplitSSEClientBase(object, metaclass=abc.ABCMeta): + """Split streaming endpoint SSE base client.""" KEEPALIVE_TIMEOUT = 70 @@ -21,6 +26,59 @@ class _Status(Enum): ERRORED = 2 CONNECTED = 3 + def __init__(self, base_url): + """ + Construct a split sse client. + + :param base_url: scheme + :// + host + :type base_url: str + """ + self._base_url = base_url + + @staticmethod + def _format_channels(channels): + """ + Format channels into a list from the raw object retrieved in the token. + + :param channels: object as extracted from the JWT capabilities. + :type channels: dict[str,list[str]] + + :returns: channels as a list of strings. + :rtype: list[str] + """ + regular = [k for (k, v) in channels.items() if v == ['subscribe']] + occupancy = ['[?occupancy=metrics.publishers]' + k + for (k, v) in channels.items() + if 'channel-metadata:publishers' in v] + return regular + occupancy + + def _build_url(self, token): + """ + Build the url to connect to and return it as a string. + + :param token: (parsed) JWT + :type token: splitio.models.token.Token + + :returns: true if the connection was successful. False otherwise. + :rtype: bool + """ + return '{base}/event-stream?v=1.1&accessToken={token}&channels={channels}'.format( + base=self._base_url, + token=token.token, + channels=','.join(self._format_channels(token.channels))) + + @abc.abstractmethod + def start(self, token): + """Open a connection to start listening for events.""" + + @abc.abstractmethod + def stop(self, blocking=False, timeout=None): + """Abort the ongoing connection.""" + + +class SplitSSEClient(SplitSSEClientBase): # pylint: disable=too-many-instance-attributes + """Split streaming endpoint SSE client.""" + def __init__(self, event_callback, sdk_metadata, first_event_callback=None, connection_closed_callback=None, client_key=None, base_url='https://streaming.split.io'): @@ -45,11 +103,11 @@ def __init__(self, event_callback, sdk_metadata, first_event_callback=None, :param client_key: client key. :type client_key: str """ + SplitSSEClientBase.__init__(self, base_url) self._client = SSEClient(self._raw_event_handler) self._callback = event_callback self._on_connected = first_event_callback self._on_disconnected = connection_closed_callback - self._base_url = base_url self._status = SplitSSEClient._Status.IDLE self._sse_first_event = None self._sse_connection_closed = None @@ -72,38 +130,6 @@ def _raw_event_handler(self, event): if event.data is not None: self._callback(event) - @staticmethod - def _format_channels(channels): - """ - Format channels into a list from the raw object retrieved in the token. - - :param channels: object as extracted from the JWT capabilities. - :type channels: dict[str,list[str]] - - :returns: channels as a list of strings. - :rtype: list[str] - """ - regular = [k for (k, v) in channels.items() if v == ['subscribe']] - occupancy = ['[?occupancy=metrics.publishers]' + k - for (k, v) in channels.items() - if 'channel-metadata:publishers' in v] - return regular + occupancy - - def _build_url(self, token): - """ - Build the url to connect to and return it as a string. - - :param token: (parsed) JWT - :type token: splitio.models.token.Token - - :returns: true if the connection was successful. False otherwise. - :rtype: bool - """ - return '{base}/event-stream?v=1.1&accessToken={token}&channels={channels}'.format( - base=self._base_url, - token=token.token, - channels=','.join(self._format_channels(token.channels))) - def start(self, token): """ Open a connection to start listening for events. @@ -134,8 +160,7 @@ def connect(url): self._on_disconnected() url = self._build_url(token) - task = threading.Thread(target=connect, name='SSEConnection', args=(url,)) - task.setDaemon(True) + task = threading.Thread(target=connect, name='SSEConnection', args=(url,), daemon=True) task.start() event_group.wait() return self._status == SplitSSEClient._Status.CONNECTED @@ -149,3 +174,82 @@ def stop(self, blocking=False, timeout=None): self._client.shutdown() if blocking: self._sse_connection_closed.wait(timeout) + +class SplitSSEClientAsync(SplitSSEClientBase): # pylint: disable=too-many-instance-attributes + """Split streaming endpoint SSE client.""" + + def __init__(self, sdk_metadata, client_key=None, base_url='https://streaming.split.io'): + """ + Construct a split sse client. + + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + + :param client_key: client key. + :type client_key: str + + :param base_url: scheme + :// + host + :type base_url: str + """ + SplitSSEClientBase.__init__(self, base_url) + self.status = SplitSSEClient._Status.IDLE + self._metadata = headers_from_metadata(sdk_metadata, client_key) + self._client = SSEClientAsync(self.KEEPALIVE_TIMEOUT) + self._event_source = None + self._event_source_ended = asyncio.Event() + + async def start(self, token): + """ + Open a connection to start listening for events. + + :param token: (parsed) JWT + :type token: splitio.models.token.Token + + :returns: yield events received from SSEClientAsync object + :rtype: SSEEvent + """ + _LOGGER.debug(self.status) + if self.status != SplitSSEClient._Status.IDLE: + raise Exception('SseClient already started.') + + self.status = SplitSSEClient._Status.CONNECTING + url = self._build_url(token) + try: + self._event_source_ended.clear() + self._event_source = self._client.start(url, extra_headers=self._metadata) + first_event = await anext(self._event_source) + if first_event.event == SSE_EVENT_ERROR: + return + + yield first_event + self.status = SplitSSEClient._Status.CONNECTED + _LOGGER.debug("Split SSE client started") + async for event in self._event_source: + if event.data is not None: + yield event + except Exception: # pylint:disable=broad-except + _LOGGER.error('SplitSSE Client Exception') + _LOGGER.debug('stack trace: ', exc_info=True) + finally: + self.status = SplitSSEClient._Status.IDLE + _LOGGER.debug('Split sse connection ended.') + self._event_source_ended.set() + + async def stop(self): + """Abort the ongoing connection.""" + _LOGGER.debug("stopping SplitSSE Client") + if self.status == SplitSSEClient._Status.IDLE: + _LOGGER.warning('sse already closed. ignoring') + return + + await self._client.shutdown() +# catching exception to avoid task hanging + try: + await self._event_source_ended.wait() + except asyncio.CancelledError as e: + _LOGGER.debug("Exception waiting for event source ended") + _LOGGER.debug('stack trace: ', exc_info=True) + pass + + async def close_sse_http_client(self): + await self._client.close_session() diff --git a/splitio/push/splitworker.py b/splitio/push/splitworker.py deleted file mode 100644 index 4eb8ca99..00000000 --- a/splitio/push/splitworker.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Split changes processing worker.""" -import logging -import threading - - -_LOGGER = logging.getLogger(__name__) - - -class SplitWorker(object): - """Split Worker for processing updates.""" - - _centinel = object() - - def __init__(self, synchronize_split, split_queue): - """ - Class constructor. - - :param synchronize_split: handler to perform split synchronization on incoming event - :type synchronize_split: callable - - :param split_queue: queue with split updates notifications - :type split_queue: queue - """ - self._split_queue = split_queue - self._handler = synchronize_split - self._running = False - self._worker = None - - def is_running(self): - """Return whether the working is running.""" - return self._running - - def _run(self): - """Run worker handler.""" - while self.is_running(): - event = self._split_queue.get() - if not self.is_running(): - break - if event == self._centinel: - continue - _LOGGER.debug('Processing split_update %d', event.change_number) - try: - self._handler(event.change_number) - except Exception: # pylint: disable=broad-except - _LOGGER.error('Exception raised in split synchronization') - _LOGGER.debug('Exception information: ', exc_info=True) - - def start(self): - """Start worker.""" - if self.is_running(): - _LOGGER.debug('Worker is already running') - return - self._running = True - - _LOGGER.debug('Starting Split Worker') - self._worker = threading.Thread(target=self._run, name='PushSplitWorker') - self._worker.setDaemon(True) - self._worker.start() - - def stop(self): - """Stop worker.""" - _LOGGER.debug('Stopping Split Worker') - if not self.is_running(): - _LOGGER.debug('Worker is not running') - return - self._running = False - self._split_queue.put(self._centinel) diff --git a/splitio/push/sse.py b/splitio/push/sse.py index 1cbf8a5c..8cde7f98 100644 --- a/splitio/push/sse.py +++ b/splitio/push/sse.py @@ -5,20 +5,21 @@ from http.client import HTTPConnection, HTTPSConnection from urllib.parse import urlparse +from splitio.optional.loaders import asyncio, aiohttp _LOGGER = logging.getLogger(__name__) - SSE_EVENT_ERROR = 'error' SSE_EVENT_MESSAGE = 'message' - +_DEFAULT_HEADERS = {'accept': 'text/event-stream'} +_EVENT_SEPARATORS = set([b'\n', b'\r\n']) +_DEFAULT_SOCKET_READ_TIMEOUT = 70 SSEEvent = namedtuple('SSEEvent', ['event_id', 'event', 'retry', 'data']) __ENDING_CHARS = set(['\n', '']) - class EventBuilder(object): """Event builder class.""" @@ -46,13 +47,9 @@ def build(self): return SSEEvent(self._lines.get('id'), self._lines.get('event'), self._lines.get('retry'), self._lines.get('data')) - class SSEClient(object): """SSE Client implementation.""" - _DEFAULT_HEADERS = {'accept': 'text/event-stream'} - _EVENT_SEPARATORS = set([b'\n', b'\r\n']) - def __init__(self, callback): """ Construct an SSE client. @@ -81,7 +78,7 @@ def _read_events(self): elif line.startswith(b':'): # comment. Skip _LOGGER.debug("skipping sse comment") continue - elif line in self._EVENT_SEPARATORS: + elif line in _EVENT_SEPARATORS: event = event_builder.build() _LOGGER.debug("dispatching event: %s", event) self._event_callback(event) @@ -117,9 +114,7 @@ def start(self, url, extra_headers=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT) raise RuntimeError('Client already started.') self._shutdown_requested = False - url = urlparse(url) - headers = self._DEFAULT_HEADERS.copy() - headers.update(extra_headers if extra_headers is not None else {}) + url, headers = urlparse(url), get_headers(extra_headers) self._conn = (HTTPSConnection(url.hostname, url.port, timeout=timeout) if url.scheme == 'https' else HTTPConnection(url.hostname, port=url.port, timeout=timeout)) @@ -139,3 +134,100 @@ def shutdown(self): self._shutdown_requested = True self._conn.sock.shutdown(socket.SHUT_RDWR) + + +class SSEClientAsync(object): + """SSE Client implementation.""" + + def __init__(self, socket_read_timeout=_DEFAULT_SOCKET_READ_TIMEOUT): + """ + Construct an SSE client. + + :param url: url to connect to + :type url: str + + :param extra_headers: additional headers + :type extra_headers: dict[str, str] + + :param timeout: connection & read timeout + :type timeout: float + """ + self._socket_read_timeout = socket_read_timeout + socket_read_timeout * .3 + self._response = None + self._done = asyncio.Event() + client_timeout = aiohttp.ClientTimeout(total=0, sock_read=self._socket_read_timeout) + self._sess = aiohttp.ClientSession(timeout=client_timeout) + + async def start(self, url, extra_headers=None): # pylint:disable=protected-access + """ + Connect and start listening for events. + + :returns: yield event when received + :rtype: SSEEvent + """ + _LOGGER.debug("Async SSEClient Started") + if self._response is not None: + raise RuntimeError('Client already started.') + + self._done.clear() + try: + async with self._sess.get(url, headers=get_headers(extra_headers)) as response: + self._response = response + event_builder = EventBuilder() + async for line in response.content: + if line.startswith(b':'): + _LOGGER.debug("skipping emtpy line / comment") + continue + elif line in _EVENT_SEPARATORS: + _LOGGER.debug("dispatching event: %s", event_builder.build()) + yield event_builder.build() + event_builder = EventBuilder() + else: + event_builder.process_line(line) + + except Exception as exc: # pylint:disable=broad-except + if self._is_conn_closed_error(exc): + _LOGGER.debug('sse connection ended.') + return + + _LOGGER.error('http client is throwing exceptions') + _LOGGER.error('stack trace: ', exc_info=True) + + finally: + self._response = None + self._done.set() + + async def shutdown(self): + """Close connection""" + if self._response: + self._response.close() + # catching exception to avoid task hanging if a canceled exception occurred + try: + await self._done.wait() + except asyncio.CancelledError: + _LOGGER.debug("Exception waiting for SSE connection to end") + _LOGGER.debug('stack trace: ', exc_info=True) + pass + + @staticmethod + def _is_conn_closed_error(exc): + """Check if the ReadError is caused by the connection being closed.""" + return isinstance(exc, aiohttp.ClientConnectionError) and str(exc) == "Connection closed" + + async def close_session(self): + if not self._sess.closed: + await self._sess.close() + +def get_headers(extra=None): + """ + Return default headers with added custom ones if specified. + + :param extra: additional headers + :type extra: dict[str, str] + + :returns: processed Headers + :rtype: dict + """ + headers = _DEFAULT_HEADERS.copy() + headers.update(extra if extra is not None else {}) + return headers diff --git a/splitio/push/status_tracker.py b/splitio/push/status_tracker.py index 6acd5d95..ec11cb48 100644 --- a/splitio/push/status_tracker.py +++ b/splitio/push/status_tracker.py @@ -1,8 +1,10 @@ """NotificationManagerKeeper implementation.""" from enum import Enum import logging -from splitio.push.parser import ControlType +from splitio.push.parser import ControlType +from splitio.util.time import get_current_epoch_time_ms +from splitio.models.telemetry import StreamingEventTypes, SSEConnectionError, SSEStreamingStatus _LOGGER = logging.getLogger(__name__) @@ -30,10 +32,10 @@ def reset(self): self.occupancy = -1 -class PushStatusTracker(object): +class PushStatusTrackerBase(object): """Tracks status of notification manager/publishers.""" - def __init__(self): + def __init__(self, telemetry_runtime_producer): """Class constructor.""" self._publishers = {} self._last_control_message = None @@ -41,6 +43,7 @@ def __init__(self): self._timestamps = LastEventTimestamps() self._shutdown_expected = None self.reset() # Set proper initial values + self._telemetry_runtime_producer = telemetry_runtime_producer def reset(self): """ @@ -54,6 +57,66 @@ def reset(self): self._timestamps.reset() self._shutdown_expected = False + def notify_sse_shutdown_expected(self): + """Let the status tracker know that an sse shutdown has been requested.""" + self._shutdown_expected = True + + def _propagate_status(self, status): + """ + Store and propagates a new status. + + :param status: Status to propagate. + :type status: Status + + :returns: Status to propagate + :rtype: status + """ + self._last_status_propagated = status + return status + + def _occupancy_ok(self): + """ + Return whether we have enough publishers. + + :returns: True if publisher count is enough. False otherwise + :rtype: bool + """ + return any(count > 0 for (chan, count) in self._publishers.items()) + + def _get_event_type_occupancy(self, event): + return StreamingEventTypes.OCCUPANCY_PRI if event.channel[-3:] == 'pri' else StreamingEventTypes.OCCUPANCY_SEC + + def _get_next_status(self): + """ + Return the next status to propagate based on the last status. + + :returns: Next status and Streaming status for telemetry event. + :rtype: Tuple(splitio.push.status_tracker.Status, splitio.models.telemetry.SSEStreamingStatus) + """ + if self._last_status_propagated == Status.PUSH_SUBSYSTEM_UP: + if not self._occupancy_ok() \ + or self._last_control_message == ControlType.STREAMING_PAUSED: + return self._propagate_status(Status.PUSH_SUBSYSTEM_DOWN), SSEStreamingStatus.PAUSED.value + + if self._last_control_message == ControlType.STREAMING_DISABLED: + return self._propagate_status(Status.PUSH_NONRETRYABLE_ERROR), SSEStreamingStatus.DISABLED.value + + if self._last_status_propagated == Status.PUSH_SUBSYSTEM_DOWN: + if self._occupancy_ok() and self._last_control_message == ControlType.STREAMING_ENABLED: + return self._propagate_status(Status.PUSH_SUBSYSTEM_UP), SSEStreamingStatus.ENABLED.value + + if self._last_control_message == ControlType.STREAMING_DISABLED: + return self._propagate_status(Status.PUSH_NONRETRYABLE_ERROR), SSEStreamingStatus.DISABLED.value + + return None, None + +class PushStatusTracker(PushStatusTrackerBase): + """Tracks status of notification manager/publishers.""" + + def __init__(self, telemetry_runtime_producer): + """Class constructor.""" + PushStatusTrackerBase.__init__(self, telemetry_runtime_producer) + def handle_occupancy(self, event): """ Handle an incoming occupancy event. @@ -73,11 +136,17 @@ def handle_occupancy(self, event): return None if self._timestamps.occupancy > event.timestamp: - _LOGGER.info('receved an old occupancy message. ignoring.') + _LOGGER.info('received an old occupancy message. ignoring.') return None + self._timestamps.occupancy = event.timestamp self._publishers[event.channel] = event.publishers + self._telemetry_runtime_producer.record_streaming_event(( + self._get_event_type_occupancy(event), + len(self._publishers), + event.timestamp + )) return self._update_status() def handle_control_message(self, event): @@ -94,6 +163,7 @@ def handle_control_message(self, event): if self._timestamps.control > event.timestamp: _LOGGER.info('receved an old control message. ignoring.') return None + self._timestamps.control = event.timestamp self._last_control_message = event.control_type @@ -122,6 +192,7 @@ def handle_ably_error(self, event): # 2. RETRYABLE_ERROR is propagated and the connection is closed on the clint side. # By doing this we guarantee that only one error will be propagated self.notify_sse_shutdown_expected() + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.ABLY_ERROR, event.code, event.timestamp)) if event.is_retryable(): _LOGGER.info('received retryable error message. ' @@ -131,10 +202,6 @@ def handle_ably_error(self, event): _LOGGER.info('received non-retryable sse error message. Disabling streaming.') return self._propagate_status(Status.PUSH_NONRETRYABLE_ERROR) - def notify_sse_shutdown_expected(self): - """Let the status tracker know that an sse shutdown has been requested.""" - self._shutdown_expected = True - def _update_status(self): """ Evaluate the current/previous status and emit a new status message if appropriate. @@ -142,20 +209,10 @@ def _update_status(self): :returns: A new status if required. None otherwise :rtype: Optional[Status] """ - if self._last_status_propagated == Status.PUSH_SUBSYSTEM_UP: - if not self._occupancy_ok() \ - or self._last_control_message == ControlType.STREAMING_PAUSED: - return self._propagate_status(Status.PUSH_SUBSYSTEM_DOWN) - - if self._last_control_message == ControlType.STREAMING_DISABLED: - return self._propagate_status(Status.PUSH_NONRETRYABLE_ERROR) - - if self._last_status_propagated == Status.PUSH_SUBSYSTEM_DOWN: - if self._occupancy_ok() and self._last_control_message == ControlType.STREAMING_ENABLED: - return self._propagate_status(Status.PUSH_SUBSYSTEM_UP) - - if self._last_control_message == ControlType.STREAMING_DISABLED: - return self._propagate_status(Status.PUSH_NONRETRYABLE_ERROR) + next_status, telemetry_event_type = self._get_next_status() + if next_status is not None: + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.STREAMING_STATUS, telemetry_event_type, get_current_epoch_time_ms())) + return next_status return None @@ -171,27 +228,132 @@ def handle_disconnect(self): :rtype: Optional[Status] """ if not self._shutdown_expected: + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SSE_CONNECTION_ERROR, SSEConnectionError.NON_REQUESTED.value, get_current_epoch_time_ms())) return self._propagate_status(Status.PUSH_RETRYABLE_ERROR) + + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SSE_CONNECTION_ERROR, SSEConnectionError.REQUESTED.value, get_current_epoch_time_ms())) return None - def _propagate_status(self, status): +class PushStatusTrackerAsync(PushStatusTrackerBase): + """Tracks status of notification manager/publishers.""" + + def __init__(self, telemetry_runtime_producer): + """Class constructor.""" + PushStatusTrackerBase.__init__(self, telemetry_runtime_producer) + + async def handle_occupancy(self, event): """ - Store and propagates a new status. + Handle an incoming occupancy event. - :param status: Status to propagate. - :type status: Status + :param event: incoming occupancy event. + :type event: splitio.push.sse.parser.Occupancy - :returns: Status to propagate - :rtype: status + :returns: A new status if required. None otherwise + :rtype: Optional[Status] """ - self._last_status_propagated = status - return status + if self._shutdown_expected: # we don't care about occupancy if a disconnection is expected + return None - def _occupancy_ok(self): + if event.channel not in self._publishers: + _LOGGER.info("received occupancy message from an unknown channel `%s`. Ignoring", + event.channel) + return None + + if self._timestamps.occupancy > event.timestamp: + _LOGGER.info('received an old occupancy message. ignoring.') + return None + + self._timestamps.occupancy = event.timestamp + + self._publishers[event.channel] = event.publishers + await self._telemetry_runtime_producer.record_streaming_event(( + self._get_event_type_occupancy(event), + len(self._publishers), + event.timestamp + )) + return await self._update_status() + + async def handle_control_message(self, event): """ - Return whether we have enough publishers. + Handle an incoming Control event. - :returns: True if publisher count is enough. False otherwise - :rtype: bool + :param event: Incoming control event + :type event: splitio.push.parser.ControlMessage """ - return any(count > 0 for (chan, count) in self._publishers.items()) + # we don't care about control messages if a disconnection is expected + if self._shutdown_expected: + return None + + if self._timestamps.control > event.timestamp: + _LOGGER.info('receved an old control message. ignoring.') + return None + + self._timestamps.control = event.timestamp + + self._last_control_message = event.control_type + return await self._update_status() + + async def handle_ably_error(self, event): + """ + Handle an ably-specific error. + + :param event: parsed ably error + :type event: splitio.push.parser.AblyError + + :returns: A new status if required. None otherwise + :rtype: Optional[Status] + """ + if self._shutdown_expected: # we don't care about an incoming error if a shutdown is expected + return None + + _LOGGER.debug('handling ably error event: %s', str(event)) + if event.should_be_ignored(): + _LOGGER.debug('ignoring sse error message: %s', event) + return None + + # Indicate that the connection will eventually end. 2 possibilities: + # 1. The server closes the connection after sending the error + # 2. RETRYABLE_ERROR is propagated and the connection is closed on the clint side. + # By doing this we guarantee that only one error will be propagated + self.notify_sse_shutdown_expected() + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.ABLY_ERROR, event.code, event.timestamp)) + + if event.is_retryable(): + _LOGGER.info('received retryable error message. ' + 'Restarting the whole flow with backoff.') + return self._propagate_status(Status.PUSH_RETRYABLE_ERROR) + + _LOGGER.info('received non-retryable sse error message. Disabling streaming.') + return self._propagate_status(Status.PUSH_NONRETRYABLE_ERROR) + + async def _update_status(self): + """ + Evaluate the current/previous status and emit a new status message if appropriate. + + :returns: A new status if required. None otherwise + :rtype: Optional[Status] + """ + next_status, telemetry_event_type = self._get_next_status() + if next_status is not None: + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.STREAMING_STATUS, telemetry_event_type, get_current_epoch_time_ms())) + return next_status + + return None + + async def handle_disconnect(self): + """ + Handle non-requested SSE disconnection. + + It should properly handle: + - connection reset/timeout + - disconnection after an ably error + + :returns: A new status if required. None otherwise + :rtype: Optional[Status] + """ + if not self._shutdown_expected: + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SSE_CONNECTION_ERROR, SSEConnectionError.NON_REQUESTED.value, get_current_epoch_time_ms())) + return self._propagate_status(Status.PUSH_RETRYABLE_ERROR) + + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SSE_CONNECTION_ERROR, SSEConnectionError.REQUESTED.value, get_current_epoch_time_ms())) + return None diff --git a/splitio/push/workers.py b/splitio/push/workers.py new file mode 100644 index 00000000..e0dd8369 --- /dev/null +++ b/splitio/push/workers.py @@ -0,0 +1,445 @@ +"""Segment changes processing worker.""" +import logging +import threading +import abc +import gzip +import zlib +import base64 +import json +from enum import Enum + +from splitio.models.splits import from_raw +from splitio.models.rule_based_segments import from_raw as rbs_from_raw +from splitio.models.telemetry import UpdateFromSSE +from splitio.push import SplitStorageException +from splitio.push.parser import UpdateType +from splitio.optional.loaders import asyncio +from splitio.util.storage_helper import update_feature_flag_storage, update_feature_flag_storage_async, \ + update_rule_based_segment_storage, update_rule_based_segment_storage_async + +_LOGGER = logging.getLogger(__name__) + +class CompressionMode(Enum): + """Compression modes """ + + NO_COMPRESSION = 0 + GZIP_COMPRESSION = 1 + ZLIB_COMPRESSION = 2 + +_compression_handlers = { + CompressionMode.NO_COMPRESSION: lambda event: base64.b64decode(event.object_definition), + CompressionMode.GZIP_COMPRESSION: lambda event: gzip.decompress(base64.b64decode(event.object_definition)).decode('utf-8'), + CompressionMode.ZLIB_COMPRESSION: lambda event: zlib.decompress(base64.b64decode(event.object_definition)).decode('utf-8'), +} + +class WorkerBase(object, metaclass=abc.ABCMeta): + """Worker template.""" + + _fetching_segment = "Fetching new segment {segment_name}" + + @abc.abstractmethod + def is_running(self): + """Return whether the working is running.""" + + @abc.abstractmethod + def start(self): + """Start worker.""" + + @abc.abstractmethod + def stop(self): + """Stop worker.""" + + def _get_object_definition(self, event): + """return feature flag or rule based segment definition in event.""" + cm = CompressionMode(event.compression) # will throw if the number is not defined in compression mode + return _compression_handlers[cm](event) + + def _get_referenced_rbs(self, feature_flag): + referenced_rbs = set() + for condition in feature_flag.conditions: + for matcher in condition.matchers: + raw_matcher = matcher.to_json() + if raw_matcher['matcherType'] == 'IN_RULE_BASED_SEGMENT': + referenced_rbs.add(raw_matcher['userDefinedSegmentMatcherData']['segmentName']) + return referenced_rbs + +class SegmentWorker(WorkerBase): + """Segment Worker for processing updates.""" + + _centinel = object() + + def __init__(self, synchronize_segment, segment_queue): + """ + Class constructor. + + :param synchronize_segment: handler to perform segment synchronization on incoming event + :type synchronize_segment: function + + :param segment_queue: queue with segment updates notifications + :type segment_queue: queue + """ + self._segment_queue = segment_queue + self._handler = synchronize_segment + self._running = False + self._worker = None + + def is_running(self): + """Return whether the working is running.""" + return self._running + + def _run(self): + """Run worker handler.""" + while self.is_running(): + event = self._segment_queue.get() + if not self.is_running(): + break + if event == self._centinel: + continue + _LOGGER.debug('Processing segment_update: %s, change_number: %d', + event.segment_name, event.change_number) + try: + self._handler(event.segment_name, event.change_number) + except Exception: + _LOGGER.error('Exception raised in segment synchronization') + _LOGGER.debug('Exception information: ', exc_info=True) + + def start(self): + """Start worker.""" + if self.is_running(): + _LOGGER.debug('Worker is already running') + return + self._running = True + + _LOGGER.debug('Starting Segment Worker') + self._worker = threading.Thread(target=self._run, name='PushSegmentWorker', daemon=True) + self._worker.start() + + def stop(self): + """Stop worker.""" + _LOGGER.debug('Stopping Segment Worker') + if not self.is_running(): + _LOGGER.debug('Worker is not running. Ignoring.') + return + self._running = False + self._segment_queue.put(self._centinel) + +class SegmentWorkerAsync(WorkerBase): + """Segment Worker for processing updates.""" + + _centinel = object() + + def __init__(self, synchronize_segment, segment_queue): + """ + Class constructor. + + :param synchronize_segment: handler to perform segment synchronization on incoming event + :type synchronize_segment: function + + :param segment_queue: queue with segment updates notifications + :type segment_queue: asyncio.Queue + """ + self._segment_queue = segment_queue + self._handler = synchronize_segment + self._running = False + + def is_running(self): + """Return whether the working is running.""" + return self._running + + async def _run(self): + """Run worker handler.""" + while self.is_running(): + event = await self._segment_queue.get() + if not self.is_running(): + break + if event == self._centinel: + continue + _LOGGER.debug('Processing segment_update: %s, change_number: %d', + event.segment_name, event.change_number) + try: + await self._handler(event.segment_name, event.change_number) + except Exception: + _LOGGER.error('Exception raised in segment synchronization') + _LOGGER.debug('Exception information: ', exc_info=True) + + def start(self): + """Start worker.""" + if self.is_running(): + _LOGGER.debug('Worker is already running') + return + self._running = True + + _LOGGER.debug('Starting Segment Worker') + asyncio.get_running_loop().create_task(self._run()) + + async def stop(self): + """Stop worker.""" + _LOGGER.debug('Stopping Segment Worker') + if not self.is_running(): + _LOGGER.debug('Worker is not running. Ignoring.') + return + self._running = False + await self._segment_queue.put(self._centinel) + +class SplitWorker(WorkerBase): + """Feature Flag Worker for processing updates.""" + + _centinel = object() + + def __init__(self, synchronize_feature_flag, synchronize_segment, feature_flag_queue, feature_flag_storage, segment_storage, telemetry_runtime_producer, rule_based_segment_storage): + """ + Class constructor. + + :param synchronize_feature_flag: handler to perform feature flag synchronization on incoming event + :type synchronize_feature_flag: callable + :param synchronize_segment: handler to perform segment synchronization on incoming event + :type synchronize_segment: function + :param feature_flag_queue: queue with feature flag updates notifications + :type feature_flag_queue: queue + :param feature_flag_storage: feature flag storage instance + :type feature_flag_storage: splitio.storage.inmemory.InMemorySplitStorage + :param segment_storage: segment storage instance + :type segment_storage: splitio.storage.inmemory.InMemorySegmentStorage + :param telemetry_runtime_producer: Telemetry runtime producer instance + :type telemetry_runtime_producer: splitio.engine.telemetry.TelemetryRuntimeProducer + :param rule_based_segment_storage: Rule based segment Storage. + :type rule_based_segment_storage: splitio.storage.InMemoryRuleBasedStorage + """ + self._feature_flag_queue = feature_flag_queue + self._handler = synchronize_feature_flag + self._segment_handler = synchronize_segment + self._running = False + self._worker = None + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + self._telemetry_runtime_producer = telemetry_runtime_producer + self._rule_based_segment_storage = rule_based_segment_storage + + def is_running(self): + """Return whether the working is running.""" + return self._running + + def _apply_iff_if_needed(self, event): + if not self._check_instant_ff_update(event): + return False + try: + if event.update_type == UpdateType.SPLIT_UPDATE: + new_feature_flag = from_raw(json.loads(self._get_object_definition(event))) + segment_list = update_feature_flag_storage(self._feature_flag_storage, [new_feature_flag], event.change_number) + for segment_name in segment_list: + if self._segment_storage.get(segment_name) is None: + _LOGGER.debug(self._fetching_segment.format(segment_name=segment_name)) + self._segment_handler(segment_name, event.change_number) + + referenced_rbs = self._get_referenced_rbs(new_feature_flag) + self._fetch_rbs_segment_if_needed(referenced_rbs, event) + self._telemetry_runtime_producer.record_update_from_sse(UpdateFromSSE.SPLIT_UPDATE) + else: + new_rbs = rbs_from_raw(json.loads(self._get_object_definition(event))) + segment_list = update_rule_based_segment_storage(self._rule_based_segment_storage, [new_rbs], event.change_number) + for segment_name in segment_list: + if self._segment_storage.get(segment_name) is None: + _LOGGER.debug(self._fetching_segment.format(segment_name=segment_name)) + self._segment_handler(segment_name, event.change_number) + self._telemetry_runtime_producer.record_update_from_sse(UpdateFromSSE.RBS_UPDATE) + return True + + except Exception as e: + raise SplitStorageException(e) + + def _fetch_rbs_segment_if_needed(self, referenced_rbs, event): + if len(referenced_rbs) > 0 and not self._rule_based_segment_storage.contains(referenced_rbs): + _LOGGER.debug('Fetching new rule based segment(s) %s', referenced_rbs) + self._handler(None, event.change_number) + + def _check_instant_ff_update(self, event): + if event.update_type == UpdateType.SPLIT_UPDATE and event.compression is not None and event.previous_change_number == self._feature_flag_storage.get_change_number(): + return True + + if event.update_type == UpdateType.RB_SEGMENT_UPDATE and event.compression is not None and event.previous_change_number == self._rule_based_segment_storage.get_change_number(): + return True + + return False + + def _run(self): + """Run worker handler.""" + while self.is_running(): + event = self._feature_flag_queue.get() + if not self.is_running(): + break + if event == self._centinel: + continue + + _LOGGER.debug('Processing feature flag update %d', event.change_number) + try: + if self._apply_iff_if_needed(event): + continue + + till = None + rbs_till = None + till, rbs_till = self._check_update_type(till, rbs_till, event) + sync_result = self._handler(till, rbs_till) + if not sync_result.success and sync_result.error_code is not None and sync_result.error_code == 414: + _LOGGER.error("URI too long exception caught, sync failed") + + if not sync_result.success: + _LOGGER.error("feature flags sync failed") + + except SplitStorageException as e: # pylint: disable=broad-except + _LOGGER.error('Exception Updating Feature Flag') + _LOGGER.debug('Exception information: ', exc_info=True) + except Exception as e: # pylint: disable=broad-except + _LOGGER.error('Exception raised in feature flag synchronization') + _LOGGER.debug('Exception information: ', exc_info=True) + + def _check_update_type(self, till, rbs_till, event): + if event.update_type == UpdateType.SPLIT_UPDATE: + till = event.change_number + else: + rbs_till = event.change_number + + return till, rbs_till + + def start(self): + """Start worker.""" + if self.is_running(): + _LOGGER.debug('Worker is already running') + return + self._running = True + + _LOGGER.debug('Starting Feature Flag Worker') + self._worker = threading.Thread(target=self._run, name='PushFeatureFlagWorker', daemon=True) + self._worker.start() + + def stop(self): + """Stop worker.""" + _LOGGER.debug('Stopping Feature Flag Worker') + if not self.is_running(): + _LOGGER.debug('Worker is not running') + return + self._running = False + self._feature_flag_queue.put(self._centinel) + +class SplitWorkerAsync(WorkerBase): + """Split Worker for processing updates.""" + + _centinel = object() + + def __init__(self, synchronize_feature_flag, synchronize_segment, feature_flag_queue, feature_flag_storage, segment_storage, telemetry_runtime_producer, rule_based_segment_storage): + """ + Class constructor. + + :param synchronize_feature_flag: handler to perform feature_flag synchronization on incoming event + :type synchronize_feature_flag: callable + :param synchronize_segment: handler to perform segment synchronization on incoming event + :type synchronize_segment: function + :param feature_flag_queue: queue with feature_flag updates notifications + :type feature_flag_queue: queue + :param feature_flag_storage: feature flag storage instance + :type feature_flag_storage: splitio.storage.inmemory.InMemorySplitStorage + :param segment_storage: segment storage instance + :type segment_storage: splitio.storage.inmemory.InMemorySegmentStorage + :param telemetry_runtime_producer: Telemetry runtime producer instance + :type telemetry_runtime_producer: splitio.engine.telemetry.TelemetryRuntimeProducer + :param rule_based_segment_storage: Rule based segment Storage. + :type rule_based_segment_storage: splitio.storage.InMemoryRuleBasedStorage + """ + self._feature_flag_queue = feature_flag_queue + self._handler = synchronize_feature_flag + self._segment_handler = synchronize_segment + self._running = False + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + self._telemetry_runtime_producer = telemetry_runtime_producer + self._rule_based_segment_storage = rule_based_segment_storage + + def is_running(self): + """Return whether the working is running.""" + return self._running + + async def _apply_iff_if_needed(self, event): + if not await self._check_instant_ff_update(event): + return False + try: + if event.update_type == UpdateType.SPLIT_UPDATE: + new_feature_flag = from_raw(json.loads(self._get_object_definition(event))) + segment_list = await update_feature_flag_storage_async(self._feature_flag_storage, [new_feature_flag], event.change_number) + for segment_name in segment_list: + if await self._segment_storage.get(segment_name) is None: + _LOGGER.debug(self._fetching_segment.format(segment_name=segment_name)) + await self._segment_handler(segment_name, event.change_number) + + referenced_rbs = self._get_referenced_rbs(new_feature_flag) + await self._fetch_rbs_segment_if_needed(referenced_rbs, event) + await self._telemetry_runtime_producer.record_update_from_sse(UpdateFromSSE.SPLIT_UPDATE) + else: + new_rbs = rbs_from_raw(json.loads(self._get_object_definition(event))) + segment_list = await update_rule_based_segment_storage_async(self._rule_based_segment_storage, [new_rbs], event.change_number) + for segment_name in segment_list: + if await self._segment_storage.get(segment_name) is None: + _LOGGER.debug(self._fetching_segment.format(segment_name=segment_name)) + await self._segment_handler(segment_name, event.change_number) + await self._telemetry_runtime_producer.record_update_from_sse(UpdateFromSSE.RBS_UPDATE) + return True + + except Exception as e: + raise SplitStorageException(e) + + async def _fetch_rbs_segment_if_needed(self, referenced_rbs, event): + if len(referenced_rbs) > 0 and not await self._rule_based_segment_storage.contains(referenced_rbs): + _LOGGER.debug('Fetching new rule based segment(s) %s', referenced_rbs) + await self._handler(None, event.change_number) + + async def _check_instant_ff_update(self, event): + if event.update_type == UpdateType.SPLIT_UPDATE and event.compression is not None and event.previous_change_number == await self._feature_flag_storage.get_change_number(): + return True + + if event.update_type == UpdateType.RB_SEGMENT_UPDATE and event.compression is not None and event.previous_change_number == await self._rule_based_segment_storage.get_change_number(): + return True + + return False + + async def _run(self): + """Run worker handler.""" + while self.is_running(): + event = await self._feature_flag_queue.get() + if not self.is_running(): + break + if event == self._centinel: + continue + _LOGGER.debug('Processing split_update %d', event.change_number) + try: + if await self._apply_iff_if_needed(event): + continue + till = None + rbs_till = None + if event.update_type == UpdateType.SPLIT_UPDATE: + till = event.change_number + else: + rbs_till = event.change_number + await self._handler(till, rbs_till) + except SplitStorageException as e: # pylint: disable=broad-except + _LOGGER.error('Exception Updating Feature Flag') + _LOGGER.debug('Exception information: ', exc_info=True) + except Exception as e: # pylint: disable=broad-except + _LOGGER.error('Exception raised in split synchronization') + _LOGGER.debug('Exception information: ', exc_info=True) + + def start(self): + """Start worker.""" + if self.is_running(): + _LOGGER.debug('Worker is already running') + return + self._running = True + + _LOGGER.debug('Starting Split Worker') + asyncio.get_running_loop().create_task(self._run()) + + async def stop(self): + """Stop worker.""" + _LOGGER.debug('Stopping Split Worker') + if not self.is_running(): + _LOGGER.debug('Worker is not running') + return + self._running = False + await self._feature_flag_queue.put(self._centinel) diff --git a/splitio/recorder/recorder.py b/splitio/recorder/recorder.py index 42ac2082..4c0ec155 100644 --- a/splitio/recorder/recorder.py +++ b/splitio/recorder/recorder.py @@ -3,9 +3,11 @@ import logging import random - from splitio.client.config import DEFAULT_DATA_SAMPLING - +from splitio.client.listener import ImpressionListenerException +from splitio.models.telemetry import MethodExceptionsAndLatencies +from splitio.models import telemetry +from splitio.optional.loaders import asyncio _LOGGER = logging.getLogger(__name__) @@ -13,6 +15,28 @@ class StatsRecorder(object, metaclass=abc.ABCMeta): """StatsRecorder interface.""" + def __init__(self, impressions_manager, event_storage, impression_storage, listener=None, unique_keys_tracker=None, imp_counter=None): + """ + Class constructor. + + :param impressions_manager: impression manager instance + :type impressions_manager: splitio.engine.impressions.Manager + :param event_storage: event storage instance + :type event_storage: splitio.storage.EventStorage + :param impression_storage: impression storage instance + :type impression_storage: splitio.storage.ImpressionStorage + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter + """ + self._impressions_manager = impressions_manager + self._event_sotrage = event_storage + self._impression_storage = impression_storage + self._listener = listener + self._unique_keys_tracker = unique_keys_tracker + self._imp_counter = imp_counter + @abc.abstractmethod def record_treatment_stats(self, impressions, latency, operation): """ @@ -37,11 +61,44 @@ def record_track_stats(self, events): """ pass +class StatsRecorderThreadingBase(StatsRecorder): + """StandardRecorder class.""" + + def __init__(self, impressions_manager, event_storage, impression_storage, listener=None, unique_keys_tracker=None, imp_counter=None): + """ + Class constructor. + + :param impressions_manager: impression manager instance + :type impressions_manager: splitio.engine.impressions.Manager + :param event_storage: event storage instance + :type event_storage: splitio.storage.EventStorage + :param impression_storage: impression storage instance + :type impression_storage: splitio.storage.ImpressionStorage + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter + """ + StatsRecorder.__init__(self, impressions_manager, event_storage, impression_storage, listener, unique_keys_tracker, imp_counter) -class StandardRecorder(StatsRecorder): + def _send_impressions_to_listener(self, impressions): + """ + Send impression result to custom listener. + + :param impressions: List of impression objects with attributes + :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + """ + if self._listener is not None: + try: + for impression, attributes in impressions: + self._listener.log_impression(impression, attributes) + except ImpressionListenerException: + pass + +class StatsRecorderAsyncBase(StatsRecorder): """StandardRecorder class.""" - def __init__(self, impressions_manager, event_storage, impression_storage): + def __init__(self, impressions_manager, event_storage, impression_storage, listener=None, unique_keys_tracker=None, imp_counter=None): """ Class constructor. @@ -51,12 +108,50 @@ def __init__(self, impressions_manager, event_storage, impression_storage): :type event_storage: splitio.storage.EventStorage :param impression_storage: impression storage instance :type impression_storage: splitio.storage.ImpressionStorage + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter """ - self._impressions_manager = impressions_manager - self._event_sotrage = event_storage - self._impression_storage = impression_storage + StatsRecorder.__init__(self, impressions_manager, event_storage, impression_storage, listener, unique_keys_tracker, imp_counter) - def record_treatment_stats(self, impressions, latency, operation): + async def _send_impressions_to_listener_async(self, impressions): + """ + Send impression result to custom listener. + + :param impressions: List of impression objects with attributes + :type impressions: list[tuple[splitio.models.impression.Impression, dict]] + """ + if self._listener is not None: + try: + for impression, attributes in impressions: + await self._listener.log_impression(impression, attributes) + except ImpressionListenerException: + pass + +class StandardRecorder(StatsRecorderThreadingBase): + """StandardRecorder class.""" + + def __init__(self, impressions_manager, event_storage, impression_storage, telemetry_evaluation_producer, telemetry_runtime_producer, listener=None, unique_keys_tracker=None, imp_counter=None): + """ + Class constructor. + + :param impressions_manager: impression manager instance + :type impressions_manager: splitio.engine.impressions.Manager + :param event_storage: event storage instance + :type event_storage: splitio.storage.EventStorage + :param impression_storage: impression storage instance + :type impression_storage: splitio.storage.ImpressionStorage + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter + """ + StatsRecorderThreadingBase.__init__(self, impressions_manager, event_storage, impression_storage, listener, unique_keys_tracker, imp_counter) + self._telemetry_evaluation_producer = telemetry_evaluation_producer + self._telemetry_runtime_producer = telemetry_runtime_producer + + def record_treatment_stats(self, impressions_decorated, latency, operation, method_name): """ Record stats for treatment evaluation. @@ -68,27 +163,97 @@ def record_treatment_stats(self, impressions, latency, operation): :type operation: str """ try: - impressions = self._impressions_manager.process_impressions(impressions) + if method_name is not None: + self._telemetry_evaluation_producer.record_latency(operation, latency) + impressions, deduped, for_listener, for_counter, for_unique_keys_tracker = self._impressions_manager.process_impressions(impressions_decorated) + if deduped > 0: + self._telemetry_runtime_producer.record_impression_stats(telemetry.CounterConstants.IMPRESSIONS_DEDUPED, deduped) self._impression_storage.put(impressions) + self._send_impressions_to_listener(for_listener) + if len(for_counter) > 0: + self._imp_counter.track(for_counter) + if len(for_unique_keys_tracker) > 0: + [self._unique_keys_tracker.track(item[0], item[1]) for item in for_unique_keys_tracker] except Exception: # pylint: disable=broad-except _LOGGER.error('Error recording impressions') _LOGGER.debug('Error: ', exc_info=True) - def record_track_stats(self, event): + def record_track_stats(self, event, latency): """ Record stats for tracking events. :param event: events tracked :type event: splitio.models.events.EventWrapper """ + self._telemetry_evaluation_producer.record_latency(MethodExceptionsAndLatencies.TRACK, latency) return self._event_sotrage.put(event) +class StandardRecorderAsync(StatsRecorderAsyncBase): + """StandardRecorder async class.""" + + def __init__(self, impressions_manager, event_storage, impression_storage, telemetry_evaluation_producer, telemetry_runtime_producer, listener=None, unique_keys_tracker=None, imp_counter=None): + """ + Class constructor. + + :param impressions_manager: impression manager instance + :type impressions_manager: splitio.engine.impressions.Manager + :param event_storage: event storage instance + :type event_storage: splitio.storage.EventStorage + :param impression_storage: impression storage instance + :type impression_storage: splitio.storage.ImpressionStorage + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTrackerAsync + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter + """ + StatsRecorderAsyncBase.__init__(self, impressions_manager, event_storage, impression_storage, listener, unique_keys_tracker, imp_counter) + self._telemetry_evaluation_producer = telemetry_evaluation_producer + self._telemetry_runtime_producer = telemetry_runtime_producer + + async def record_treatment_stats(self, impressions_decorated, latency, operation, method_name): + """ + Record stats for treatment evaluation. + + :param impressions: impressions generated for each evaluation performed + :type impressions: array + :param latency: time took for doing evaluation + :type latency: int + :param operation: operation type + :type operation: str + """ + try: + if method_name is not None: + await self._telemetry_evaluation_producer.record_latency(operation, latency) + impressions, deduped, for_listener, for_counter, for_unique_keys_tracker = self._impressions_manager.process_impressions(impressions_decorated) + if deduped > 0: + await self._telemetry_runtime_producer.record_impression_stats(telemetry.CounterConstants.IMPRESSIONS_DEDUPED, deduped) + + await self._impression_storage.put(impressions) + await self._send_impressions_to_listener_async(for_listener) + if len(for_counter) > 0: + self._imp_counter.track(for_counter) + if len(for_unique_keys_tracker) > 0: + unique_keys_coros = [self._unique_keys_tracker.track(item[0], item[1]) for item in for_unique_keys_tracker] + await asyncio.gather(*unique_keys_coros) + except Exception: # pylint: disable=broad-except + _LOGGER.error('Error recording impressions') + _LOGGER.debug('Error: ', exc_info=True) + + async def record_track_stats(self, event, latency): + """ + Record stats for tracking events. -class PipelinedRecorder(StatsRecorder): + :param event: events tracked + :type event: splitio.models.events.EventWrapper + """ + await self._telemetry_evaluation_producer.record_latency(MethodExceptionsAndLatencies.TRACK, latency) + return await self._event_sotrage.put(event) + +class PipelinedRecorder(StatsRecorderThreadingBase): """PipelinedRecorder class.""" def __init__(self, pipe, impressions_manager, event_storage, - impression_storage, data_sampling=DEFAULT_DATA_SAMPLING): + impression_storage, telemetry_redis_storage, data_sampling=DEFAULT_DATA_SAMPLING, listener=None, unique_keys_tracker=None, imp_counter=None): """ Class constructor. @@ -102,14 +267,17 @@ def __init__(self, pipe, impressions_manager, event_storage, :type impression_storage: splitio.storage.redis.RedisImpressionsStorage :param data_sampling: data sampling factor :type data_sampling: number + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTracker + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter """ + StatsRecorderThreadingBase.__init__(self, impressions_manager, event_storage, impression_storage, listener, unique_keys_tracker, imp_counter) self._make_pipe = pipe - self._impressions_manager = impressions_manager - self._event_sotrage = event_storage - self._impression_storage = impression_storage self._data_sampling = data_sampling + self._telemetry_redis_storage = telemetry_redis_storage - def record_treatment_stats(self, impressions, latency, operation): + def record_treatment_stats(self, impressions_decorated, latency, operation, method_name): """ Record stats for treatment evaluation. @@ -121,30 +289,143 @@ def record_treatment_stats(self, impressions, latency, operation): :type operation: str """ try: - # TODO @matias.melograno - # Changing logic until TelemetryV2 released to avoid using pipelined operations - # Deprecated Old Telemetry if self._data_sampling < DEFAULT_DATA_SAMPLING: rnumber = random.uniform(0, 1) if self._data_sampling < rnumber: return - impressions = self._impressions_manager.process_impressions(impressions) - # pipe = self._make_pipe() - # self._impression_storage.add_impressions_to_pipe(impressions, pipe) - # self._telemetry_storage.add_latency_to_pipe(operation, latency, pipe) - # result = pipe.execute() - # if len(result) == 2: - # self._impression_storage.expire_key(result[0], len(impressions)) - self._impression_storage.put(impressions) + + impressions, deduped, for_listener, for_counter, for_unique_keys_tracker = self._impressions_manager.process_impressions(impressions_decorated) + if impressions: + pipe = self._make_pipe() + self._impression_storage.add_impressions_to_pipe(impressions, pipe) + if method_name is not None: + self._telemetry_redis_storage.add_latency_to_pipe(operation, latency, pipe) + result = pipe.execute() + if len(result) == 2: + self._impression_storage.expire_key(result[0], len(impressions)) + self._telemetry_redis_storage.expire_latency_keys(result[1], latency) + self._send_impressions_to_listener(for_listener) + + if len(for_counter) > 0: + self._imp_counter.track(for_counter) + if len(for_unique_keys_tracker) > 0: + [self._unique_keys_tracker.track(item[0], item[1]) for item in for_unique_keys_tracker] except Exception: # pylint: disable=broad-except _LOGGER.error('Error recording impressions') _LOGGER.debug('Error: ', exc_info=True) - def record_track_stats(self, event): + def record_track_stats(self, event, latency): """ Record stats for tracking events. :param event: events tracked :type event: splitio.models.events.EventWrapper """ - return self._event_sotrage.put(event) + try: + pipe = self._make_pipe() + self._event_sotrage.add_events_to_pipe(event, pipe) + self._telemetry_redis_storage.add_latency_to_pipe(MethodExceptionsAndLatencies.TRACK, latency, pipe) + result = pipe.execute() + if len(result) == 2: + self._event_sotrage.expire_keys(result[0], len(event)) + self._telemetry_redis_storage.expire_latency_keys(result[1], latency) + if result[0] > 0: + return True + + return False + + except Exception: # pylint: disable=broad-except + _LOGGER.error('Error recording events') + _LOGGER.debug('Error: ', exc_info=True) + return False + +class PipelinedRecorderAsync(StatsRecorderAsyncBase): + """PipelinedRecorder async class.""" + + def __init__(self, pipe, impressions_manager, event_storage, + impression_storage, telemetry_redis_storage, data_sampling=DEFAULT_DATA_SAMPLING, listener=None, unique_keys_tracker=None, imp_counter=None): + """ + Class constructor. + + :param pipe: redis pipeline function + :type pipe: callable + :param impressions_manager: impression manager instance + :type impressions_manager: splitio.engine.impressions.Manager + :param event_storage: event storage instance + :type event_storage: splitio.storage.EventStorage + :param impression_storage: impression storage instance + :type impression_storage: splitio.storage.redis.RedisImpressionsStorage + :param data_sampling: data sampling factor + :type data_sampling: number + :param unique_keys_tracker: Unique Keys Tracker instance + :type unique_keys_tracker: splitio.engine.unique_keys_tracker.UniqueKeysTrackerAsync + :param imp_counter: Impressions Counter instance + :type imp_counter: splitio.engine.impressions.Counter + """ + StatsRecorderAsyncBase.__init__(self, impressions_manager, event_storage, impression_storage, listener, unique_keys_tracker, imp_counter) + self._make_pipe = pipe + self._data_sampling = data_sampling + self._telemetry_redis_storage = telemetry_redis_storage + + async def record_treatment_stats(self, impressions_decorated, latency, operation, method_name): + """ + Record stats for treatment evaluation. + + :param impressions: impressions generated for each evaluation performed + :type impressions: array + :param latency: time took for doing evaluation + :type latency: int + :param operation: operation type + :type operation: str + """ + try: + if self._data_sampling < DEFAULT_DATA_SAMPLING: + rnumber = random.uniform(0, 1) + if self._data_sampling < rnumber: + return + + impressions, deduped, for_listener, for_counter, for_unique_keys_tracker = self._impressions_manager.process_impressions(impressions_decorated) + if impressions: + pipe = self._make_pipe() + self._impression_storage.add_impressions_to_pipe(impressions, pipe) + if method_name is not None: + self._telemetry_redis_storage.add_latency_to_pipe(operation, latency, pipe) + result = await pipe.execute() + if len(result) == 2: + await self._impression_storage.expire_key(result[0], len(impressions)) + await self._telemetry_redis_storage.expire_latency_keys(result[1], latency) + await self._send_impressions_to_listener_async(for_listener) + + if len(for_counter) > 0: + self._imp_counter.track(for_counter) + if len(for_unique_keys_tracker) > 0: + unique_keys_coros = [self._unique_keys_tracker.track(item[0], item[1]) for item in for_unique_keys_tracker] + await asyncio.gather(*unique_keys_coros) + except Exception: # pylint: disable=broad-except + _LOGGER.error('Error recording impressions') + _LOGGER.debug('Error: ', exc_info=True) + + async def record_track_stats(self, event, latency): + """ + Record stats for tracking events. + + :param event: events tracked + :type event: splitio.models.events.EventWrapper + """ + try: + pipe = self._make_pipe() + self._event_sotrage.add_events_to_pipe(event, pipe) + self._telemetry_redis_storage.add_latency_to_pipe(MethodExceptionsAndLatencies.TRACK, latency, pipe) + result = await pipe.execute() + if len(result) == 2: + await self._event_sotrage.expire_keys(result[0], len(event)) + await self._telemetry_redis_storage.expire_latency_keys(result[1], latency) + if result[0] > 0: + return True + + return False + + except Exception: # pylint: disable=broad-except + _LOGGER.error('Error recording events') + _LOGGER.debug('Error: ', exc_info=True) + return False diff --git a/splitio/spec.py b/splitio/spec.py new file mode 100644 index 00000000..cd7588e0 --- /dev/null +++ b/splitio/spec.py @@ -0,0 +1 @@ +SPEC_VERSION = '1.3' diff --git a/splitio/storage/__init__.py b/splitio/storage/__init__.py index 23ccda31..079ee863 100644 --- a/splitio/storage/__init__.py +++ b/splitio/storage/__init__.py @@ -1,7 +1,6 @@ """Base storage interfaces.""" import abc - class SplitStorage(object, metaclass=abc.ABCMeta): """Split storage interface implemented as an abstract class.""" @@ -30,25 +29,16 @@ def fetch_many(self, split_names): pass @abc.abstractmethod - def put(self, split): - """ - Store a split. - - :param split: Split object to store - :type split_name: splitio.models.splits.Split - """ - pass - - @abc.abstractmethod - def remove(self, split_name): + def update(self, to_add, to_delete, new_change_number): """ - Remove a split from storage. + Update feature flag storage. - :param split_name: Name of the feature to remove. - :type split_name: str - - :return: True if the split was found and removed. False otherwise. - :rtype: bool + :param to_add: List of feature flags to add + :type to_add: list[splitio.models.splits.Split] + :param to_delete: List of feature flags to delete + :type to_delete: list[splitio.models.splits.Split] + :param new_change_number: New change number. + :type new_change_number: int """ pass @@ -61,16 +51,6 @@ def get_change_number(self): """ pass - @abc.abstractmethod - def set_change_number(self, new_change_number): - """ - Set the latest change number. - - :param new_change_number: New change number. - :type new_change_number: int - """ - pass - @abc.abstractmethod def get_split_names(self): """ @@ -283,3 +263,166 @@ def clear(self): Clear data. """ pass + +class TelemetryStorage(object, metaclass=abc.ABCMeta): + """Telemetry storage interface.""" + + @abc.abstractmethod + def record_config(self, config): + """ + initilize telemetry objects + + :param congif: factory configuration parameters + :type config: splitio.client.config + """ + pass + + @abc.abstractmethod + def record_latency(self, method, latency): + """ + record latency data + + :param method: method name + :type method: string + :param latency: latency + :type latency: int64 + """ + pass + + @abc.abstractmethod + def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + pass + + @abc.abstractmethod + def record_not_ready_usage(self): + """ + record not ready time + + """ + pass + + @abc.abstractmethod + def record_bur_time_out(self): + """ + record BUR timeouts + + """ + pass + +class FlagSetsFilter(object): + """Config Flagsets Filter storage.""" + + def __init__(self, flag_sets=[]): + """Constructor.""" + self.flag_sets = set(flag_sets) + self.should_filter = any(flag_sets) + self.sorted_flag_sets = sorted(flag_sets) + + def set_exist(self, flag_set): + """ + Check if a flagset exist in flagset filter + :param flag_set: set name + :type flag_set: str + + :rtype: bool + """ + if not self.should_filter: + return True + + if not isinstance(flag_set, str) or flag_set == '': + return False + + return any(self.flag_sets.intersection(set([flag_set]))) + + def intersect(self, flag_sets): + """ + Check if a set exist in config flagset filter + :param flag_set: set of flagsets + :type flag_set: set + + :rtype: bool + """ + if not self.should_filter: + return True + + if not isinstance(flag_sets, set) or len(flag_sets) == 0: + return False + + return any(self.flag_sets.intersection(flag_sets)) + +class RuleBasedSegmentsStorage(object, metaclass=abc.ABCMeta): + """SplitRule based segment storage interface implemented as an abstract class.""" + + @abc.abstractmethod + def get(self, segment_name): + """ + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str + """ + pass + + @abc.abstractmethod + def update(self, to_add, to_delete, new_change_number): + """ + Update rule based segment.. + + :param to_add: List of rule based segment. to add + :type to_add: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param to_delete: List of rule based segment. to delete + :type to_delete: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param new_change_number: New change number. + :type new_change_number: int + """ + pass + + @abc.abstractmethod + def get_change_number(self): + """ + Retrieve latest rule based segment change number. + + :rtype: int + """ + pass + + @abc.abstractmethod + def contains(self, segment_names): + """ + Return whether the segments exists in rule based segment in cache. + + :param segment_names: segment name to validate. + :type segment_names: str + + :return: True if segment names exists. False otherwise. + :rtype: bool + """ + pass + + @abc.abstractmethod + def get_segment_names(self): + """ + Retrieve a list of all excluded segments names. + + :return: List of segment names. + :rtype: list(str) + """ + pass + + @abc.abstractmethod + def get_large_segment_names(self): + """ + Retrieve a list of all excluded large segments names. + + :return: List of segment names. + :rtype: list(str) + """ + pass \ No newline at end of file diff --git a/splitio/storage/adapters/cache_trait.py b/splitio/storage/adapters/cache_trait.py index 399ee383..0e24d050 100644 --- a/splitio/storage/adapters/cache_trait.py +++ b/splitio/storage/adapters/cache_trait.py @@ -4,12 +4,13 @@ import time from functools import update_wrapper +from splitio.optional.loaders import asyncio DEFAULT_MAX_AGE = 5 DEFAULT_MAX_SIZE = 100 -class LocalMemoryCache(object): # pylint: disable=too-many-instance-attributes +class LocalMemoryCacheBase(object): # pylint: disable=too-many-instance-attributes """ Key/Value local memory cache. with expiration & LRU eviction. @@ -49,7 +50,6 @@ def __init__( ): """Class constructor.""" self._data = {} - self._lock = threading.Lock() self._max_age_seconds = max_age_seconds self._max_size = max_size self._lru = None @@ -57,41 +57,6 @@ def __init__( self._key_func = key_func self._user_func = user_func - def get(self, *args, **kwargs): - """ - Fetch an item from the cache. If it's a miss, call user function to refill. - - :param args: User supplied positional arguments - :type args: list - :param kwargs: User supplied keyword arguments - :type kwargs: dict - - :return: Cached/Fetched object - :rtype: object - """ - with self._lock: - key = self._key_func(*args, **kwargs) - node = self._data.get(key) - if node is not None: - if self._is_expired(node): - node.value = self._user_func(*args, **kwargs) - node.last_update = time.time() - else: - value = self._user_func(*args, **kwargs) - node = LocalMemoryCache._Node(key, value, time.time(), None, None) - node = self._bubble_up(node) - self._data[key] = node - self._rollover() - return node.value - - def remove_expired(self): - """Remove expired elements.""" - with self._lock: - self._data = { - key: value for (key, value) in self._data.items() - if not self._is_expired(value) - } - def clear(self): """Clear the cache.""" self._data = {} @@ -151,6 +116,106 @@ def __str__(self): node = node.previous return '\n' + '\n'.join(nodes) + '\n' +class LocalMemoryCache(LocalMemoryCacheBase): # pylint: disable=too-many-instance-attributes + """Local cache for threading""" + def __init__( + self, + key_func, + user_func, + max_age_seconds=DEFAULT_MAX_AGE, + max_size=DEFAULT_MAX_SIZE + ): + """Class constructor.""" + LocalMemoryCacheBase.__init__(self, key_func, user_func, max_age_seconds, max_size) + self._lock = threading.Lock() + + def get(self, *args, **kwargs): + """ + Fetch an item from the cache. If it's a miss, call user function to refill. + + :param args: User supplied positional arguments + :type args: list + :param kwargs: User supplied keyword arguments + :type kwargs: dict + + :return: Cached/Fetched object + :rtype: object + """ + with self._lock: + key = self._key_func(*args, **kwargs) + node = self._data.get(key) + if node is not None: + if self._is_expired(node): + node.value = self._user_func(*args, **kwargs) + node.last_update = time.time() + else: + value = self._user_func(*args, **kwargs) + node = LocalMemoryCache._Node(key, value, time.time(), None, None) + node = self._bubble_up(node) + self._data[key] = node + self._rollover() + return node.value + + + def remove_expired(self): + """Remove expired elements.""" + with self._lock: + self._data = { + key: value for (key, value) in self._data.items() + if not self._is_expired(value) + } + +class LocalMemoryCacheAsync(LocalMemoryCacheBase): # pylint: disable=too-many-instance-attributes + """Local cache for asyncio""" + def __init__( + self, + key_func, + user_func, + max_age_seconds=DEFAULT_MAX_AGE, + max_size=DEFAULT_MAX_SIZE + ): + """Class constructor.""" + LocalMemoryCacheBase.__init__(self, key_func, user_func, max_age_seconds, max_size) + self._lock = asyncio.Lock() + + async def get_key(self, key): + """ + Fetch an item from the cache, return None if does not exist + :param key: User supplied key + :type key: str/frozenset + :return: Cached/Fetched object + :rtype: object + """ + async with self._lock: + node = self._data.get(key) + if node is not None: + if self._is_expired(node): + return None + + if node is None: + return None + + node = self._bubble_up(node) + return node.value + + async def add_key(self, key, value): + """ + Add an item from the cache. + :param key: User supplied key + :type key: str/frozenset + :param value: key value + :type value: str + """ + async with self._lock: + if self._data.get(key) is not None: + node = self._data.get(key) + node.value = value + node.last_update = time.time() + else: + node = LocalMemoryCache._Node(key, value, time.time(), None, None) + node = self._bubble_up(node) + self._data[key] = node + self._rollover() def decorate(key_func, max_age_seconds=DEFAULT_MAX_AGE, max_size=DEFAULT_MAX_SIZE): """ diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index c0cf9e75..92aa2544 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -1,10 +1,12 @@ """Redis client wrapper with prefix support.""" from builtins import str - +import abc try: from redis import StrictRedis from redis.sentinel import Sentinel from redis.exceptions import RedisError + import redis.asyncio as aioredis + from redis.asyncio.sentinel import Sentinel as SentinelAsync except ImportError: def missing_redis_dependencies(*_, **__): """Fail if missing dependencies are used.""" @@ -12,8 +14,7 @@ def missing_redis_dependencies(*_, **__): 'Missing Redis support dependencies. ' 'Please use `pip install splitio_client[redis]` to install the sdk with redis support' ) - StrictRedis = Sentinel = missing_redis_dependencies - + StrictRedis = Sentinel = aioredis = missing_redis_dependencies class RedisAdapterException(Exception): """Exception to be thrown when a redis command fails with an exception.""" @@ -63,17 +64,20 @@ def add_prefix(self, k): if self._prefix: if isinstance(k, str): return '{prefix}.{key}'.format(prefix=self._prefix, key=k) + elif isinstance(k, list) and k: if isinstance(k[0], bytes): return [ '{prefix}.{key}'.format(prefix=self._prefix, key=key.decode("utf8")) for key in k ] + elif isinstance(k[0], str): return [ '{prefix}.{key}'.format(prefix=self._prefix, key=key) for key in k ] + else: return k @@ -94,8 +98,10 @@ def remove_prefix(self, k): if self._prefix: if isinstance(k, str): return k[len(self._prefix)+1:] + elif isinstance(k, list): return [key[len(self._prefix)+1:] for key in k] + else: return k @@ -103,8 +109,106 @@ def remove_prefix(self, k): "Cannot remove prefix correctly. Wrong type for key(s) provided" ) +class RedisAdapterBase(object, metaclass=abc.ABCMeta): + """Redis adapter template.""" + + @abc.abstractmethod + def keys(self, pattern): + """Mimic original redis keys.""" + + @abc.abstractmethod + def set(self, name, value, *args, **kwargs): + """Mimic original redis set.""" + + @abc.abstractmethod + def get(self, name): + """Mimic original redis get.""" + + @abc.abstractmethod + def setex(self, name, time, value): + """Mimic original redis setex.""" + + @abc.abstractmethod + def delete(self, *names): + """Mimic original redis delete.""" + + @abc.abstractmethod + def exists(self, name): + """Mimic original redis exists.""" + + @abc.abstractmethod + def lrange(self, key, start, end): + """Mimic original redis lrange.""" + + @abc.abstractmethod + def mget(self, names): + """Mimic original redis mget.""" -class RedisAdapter(object): # pylint: disable=too-many-public-methods + @abc.abstractmethod + def smembers(self, name): + """Mimic original redis smembers.""" + + @abc.abstractmethod + def sadd(self, name, *values): + """Mimic original redis sadd.""" + + @abc.abstractmethod + def srem(self, name, *values): + """Mimic original redis srem.""" + + @abc.abstractmethod + def sismember(self, name, value): + """Mimic original redis sismember.""" + + @abc.abstractmethod + def eval(self, script, number_of_keys, *keys): + """Mimic original redis eval.""" + + @abc.abstractmethod + def hset(self, name, key, value): + """Mimic original redis hset.""" + + @abc.abstractmethod + def hget(self, name, key): + """Mimic original redis hget.""" + + @abc.abstractmethod + def hincrby(self, name, key, amount=1): + """Mimic original redis hincrby.""" + + @abc.abstractmethod + def incr(self, name, amount=1): + """Mimic original redis incr.""" + + @abc.abstractmethod + def getset(self, name, value): + """Mimic original redis getset.""" + + @abc.abstractmethod + def rpush(self, key, *values): + """Mimic original redis rpush.""" + + @abc.abstractmethod + def expire(self, key, value): + """Mimic original redis expire.""" + + @abc.abstractmethod + def rpop(self, key): + """Mimic original redis rpop.""" + + @abc.abstractmethod + def ttl(self, key): + """Mimic original redis ttl.""" + + @abc.abstractmethod + def lpop(self, key): + """Mimic original redis lpop.""" + + @abc.abstractmethod + def pipeline(self): + """Mimic original redis pipeline.""" + +class RedisAdapter(RedisAdapterBase): # pylint: disable=too-many-public-methods """ Instance decorator for Redis clients such as StrictRedis. @@ -241,6 +345,13 @@ def hget(self, name, key): except RedisError as exc: raise RedisAdapterException('Error executing hget operation') from exc + def hincrby(self, name, key, amount=1): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.hincrby(self._prefix_helper.add_prefix(name), key, amount) + except RedisError as exc: + raise RedisAdapterException('Error executing hincrby operation') from exc + def incr(self, name, amount=1): """Mimic original redis function but using user custom prefix.""" try: @@ -297,10 +408,214 @@ def pipeline(self): except RedisError as exc: raise RedisAdapterException('Error executing ttl operation') from exc +class RedisAdapterAsync(RedisAdapterBase): # pylint: disable=too-many-public-methods + """ + Instance decorator for asyncio Redis clients such as StrictRedis. -class RedisPipelineAdapter(object): + Adds an extra layer handling addition/removal of user prefix when handling + keys """ - Instance decorator for Redis Pipeline. + def __init__(self, decorated, prefix=None): + """ + Store the user prefix and the redis client instance. + + :param decorated: Instance of redis cache client to decorate. + :param prefix: User prefix to add. + """ + self._decorated = decorated + self._prefix_helper = PrefixHelper(prefix) + + # Below starts a list of methods that implement the interface of a standard + # redis client. + + async def keys(self, pattern): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + key + for key in self._prefix_helper.remove_prefix(await self._decorated.keys(self._prefix_helper.add_prefix(pattern))) + ] + except RedisError as exc: + raise RedisAdapterException('Failed to execute keys operation') from exc + + async def set(self, name, value, *args, **kwargs): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.set( + self._prefix_helper.add_prefix(name), value, *args, **kwargs + ) + except RedisError as exc: + raise RedisAdapterException('Failed to execute set operation') from exc + + async def get(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.get(self._prefix_helper.add_prefix(name)) + except RedisError as exc: + raise RedisAdapterException('Error executing get operation') from exc + + async def setex(self, name, time, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.setex(self._prefix_helper.add_prefix(name), time, value) + except RedisError as exc: + raise RedisAdapterException('Error executing setex operation') from exc + + async def delete(self, *names): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.delete(*self._prefix_helper.add_prefix(list(names))) + except RedisError as exc: + raise RedisAdapterException('Error executing delete operation') from exc + + async def exists(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.exists(self._prefix_helper.add_prefix(name)) + except RedisError as exc: + raise RedisAdapterException('Error executing exists operation') from exc + + async def lrange(self, key, start, end): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.lrange(self._prefix_helper.add_prefix(key), start, end) + except RedisError as exc: + raise RedisAdapterException('Error executing exists operation') from exc + + async def mget(self, names): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + item + for item in await self._decorated.mget(self._prefix_helper.add_prefix(names)) + ] + except RedisError as exc: + raise RedisAdapterException('Error executing mget operation') from exc + + async def smembers(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + item + for item in await self._decorated.smembers(self._prefix_helper.add_prefix(name)) + ] + except RedisError as exc: + raise RedisAdapterException('Error executing smembers operation') from exc + + async def sadd(self, name, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.sadd(self._prefix_helper.add_prefix(name), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing sadd operation') from exc + + async def srem(self, name, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.srem(self._prefix_helper.add_prefix(name), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing srem operation') from exc + + async def sismember(self, name, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.sismember(self._prefix_helper.add_prefix(name), value) + except RedisError as exc: + raise RedisAdapterException('Error executing sismember operation') from exc + + async def eval(self, script, number_of_keys, *keys): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.eval(script, number_of_keys, *self._prefix_helper.add_prefix(list(keys))) + except RedisError as exc: + raise RedisAdapterException('Error executing eval operation') from exc + + async def hset(self, name, key, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.hset(self._prefix_helper.add_prefix(name), key, value) + except RedisError as exc: + raise RedisAdapterException('Error executing hset operation') from exc + + async def hget(self, name, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.hget(self._prefix_helper.add_prefix(name), key) + except RedisError as exc: + raise RedisAdapterException('Error executing hget operation') from exc + + async def hincrby(self, name, key, amount=1): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.hincrby(self._prefix_helper.add_prefix(name), key, amount) + except RedisError as exc: + raise RedisAdapterException('Error executing hincrby operation') from exc + + async def incr(self, name, amount=1): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.incr(self._prefix_helper.add_prefix(name), amount) + except RedisError as exc: + raise RedisAdapterException('Error executing incr operation') from exc + + async def getset(self, name, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.getset(self._prefix_helper.add_prefix(name), value) + except RedisError as exc: + raise RedisAdapterException('Error executing getset operation') from exc + + async def rpush(self, key, *values): + """Mimic original redis function but using user custom prefix.""" + try: + async with self._decorated.client() as conn: + return await conn.rpush(self._prefix_helper.add_prefix(key), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing rpush operation') from exc + + async def expire(self, key, value): + """Mimic original redis function but using user custom prefix.""" + try: + async with self._decorated.client() as conn: + return await conn.expire(self._prefix_helper.add_prefix(key), value) + except RedisError as exc: + raise RedisAdapterException('Error executing expire operation') from exc + + async def rpop(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.rpop(self._prefix_helper.add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing rpop operation') from exc + + async def ttl(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.ttl(self._prefix_helper.add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing ttl operation') from exc + + async def lpop(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.lpop(self._prefix_helper.add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing lpop operation') from exc + + def pipeline(self): + """Mimic original redis pipeline.""" + try: + return RedisPipelineAdapterAsync(self._decorated, self._prefix_helper) + except RedisError as exc: + raise RedisAdapterException('Error executing ttl operation') from exc + + async def close(self): + await self._decorated.close() + await self._decorated.connection_pool.disconnect(inuse_connections=True) + +class RedisPipelineAdapterBase(object): + """ + Base decorator for Redis Pipeline. Adds an extra layer handling addition/removal of user prefix when handling keys @@ -323,6 +638,30 @@ def incr(self, name, amount=1): """Mimic original redis function but using user custom prefix.""" self._pipe.incr(self._prefix_helper.add_prefix(name), amount) + def hincrby(self, name, key, amount=1): + """Mimic original redis function but using user custom prefix.""" + self._pipe.hincrby(self._prefix_helper.add_prefix(name), key, amount) + + def smembers(self, name): + """Mimic original redis function but using user custom prefix.""" + self._pipe.smembers(self._prefix_helper.add_prefix(name)) + +class RedisPipelineAdapter(RedisPipelineAdapterBase): + """ + Instance decorator for Redis Pipeline. + + Adds an extra layer handling addition/removal of user prefix when handling + keys + """ + def __init__(self, decorated, prefix_helper): + """ + Store the user prefix and the redis client instance. + + :param decorated: Instance of redis cache client to decorate. + :param _prefix_helper: PrefixHelper utility + """ + RedisPipelineAdapterBase.__init__(self, decorated, prefix_helper) + def execute(self): """Mimic original redis function but using user custom prefix.""" try: @@ -330,6 +669,28 @@ def execute(self): except RedisError as exc: raise RedisAdapterException('Error executing pipeline operation') from exc +class RedisPipelineAdapterAsync(RedisPipelineAdapterBase): + """ + Instance decorator for Asyncio Redis Pipeline. + + Adds an extra layer handling addition/removal of user prefix when handling + keys + """ + def __init__(self, decorated, prefix_helper): + """ + Store the user prefix and the redis client instance. + + :param decorated: Instance of redis cache client to decorate. + :param _prefix_helper: PrefixHelper utility + """ + RedisPipelineAdapterBase.__init__(self, decorated, prefix_helper) + + async def execute(self): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._pipe.execute() + except RedisError as exc: + raise RedisAdapterException('Error executing pipeline operation') from exc def _build_default_client(config): # pylint: disable=too-many-locals """ @@ -344,6 +705,7 @@ def _build_default_client(config): # pylint: disable=too-many-locals host = config.get('redisHost', 'localhost') port = config.get('redisPort', 6379) database = config.get('redisDb', 0) + username = config.get('redisUsername', None) password = config.get('redisPassword', None) socket_timeout = config.get('redisSocketTimeout', None) socket_connect_timeout = config.get('redisSocketConnectTimeout', None) @@ -353,7 +715,6 @@ def _build_default_client(config): # pylint: disable=too-many-locals unix_socket_path = config.get('redisUnixSocketPath', None) encoding = config.get('redisEncoding', 'utf-8') encoding_errors = config.get('redisEncodingErrors', 'strict') - errors = config.get('redisErrors', None) decode_responses = config.get('redisDecodeResponses', True) retry_on_timeout = config.get('redisRetryOnTimeout', False) ssl = config.get('redisSsl', False) @@ -369,6 +730,7 @@ def _build_default_client(config): # pylint: disable=too-many-locals port=port, db=database, password=password, + username=username, socket_timeout=socket_timeout, socket_connect_timeout=socket_connect_timeout, socket_keepalive=socket_keepalive, @@ -377,7 +739,6 @@ def _build_default_client(config): # pylint: disable=too-many-locals unix_socket_path=unix_socket_path, encoding=encoding, encoding_errors=encoding_errors, - errors=errors, decode_responses=decode_responses, retry_on_timeout=retry_on_timeout, ssl=ssl, @@ -389,6 +750,66 @@ def _build_default_client(config): # pylint: disable=too-many-locals ) return RedisAdapter(redis, prefix=prefix) +async def _build_default_client_async(config): # pylint: disable=too-many-locals + """ + Build a redis asyncio adapter. + + :param config: Redis configuration properties + :type config: dict + + :return: A wrapped Redis object + :rtype: splitio.storage.adapters.redis.RedisAdapterAsync + """ + host = config.get('redisHost', 'localhost') + port = config.get('redisPort', 6379) + database = config.get('redisDb', 0) + username = config.get('redisUsername', None) + password = config.get('redisPassword', None) + socket_timeout = config.get('redisSocketTimeout', None) + socket_connect_timeout = config.get('redisSocketConnectTimeout', None) + socket_keepalive = config.get('redisSocketKeepalive', None) + socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None) + connection_pool = config.get('redisConnectionPool', None) + unix_socket_path = config.get('redisUnixSocketPath', None) + encoding = config.get('redisEncoding', 'utf-8') + encoding_errors = config.get('redisEncodingErrors', 'strict') + decode_responses = config.get('redisDecodeResponses', True) + retry_on_timeout = config.get('redisRetryOnTimeout', False) + ssl = config.get('redisSsl', False) + ssl_keyfile = config.get('redisSslKeyfile', None) + ssl_certfile = config.get('redisSslCertfile', None) + ssl_cert_reqs = config.get('redisSslCertReqs', None) + ssl_ca_certs = config.get('redisSslCaCerts', None) + max_connections = config.get('redisMaxConnections', None) + prefix = config.get('redisPrefix') + + if connection_pool == None: + connection_pool = aioredis.ConnectionPool.from_url( + "redis://" + host + ":" + str(port), + db=database, + password=password, + username=username, + max_connections=max_connections, + encoding=encoding, + decode_responses=decode_responses, + socket_timeout=socket_timeout, + ) + redis = aioredis.Redis( + connection_pool=connection_pool, + socket_connect_timeout=socket_connect_timeout, + socket_keepalive=socket_keepalive, + socket_keepalive_options=socket_keepalive_options, + unix_socket_path=unix_socket_path, + encoding_errors=encoding_errors, + retry_on_timeout=retry_on_timeout, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs + ) + return RedisAdapterAsync(redis, prefix=prefix) + def _build_sentinel_client(config): # pylint: disable=too-many-locals """ @@ -422,6 +843,7 @@ def _build_sentinel_client(config): # pylint: disable=too-many-locals raise SentinelConfigurationException('redisMasterService must be specified.') database = config.get('redisDb', 0) + username = config.get('redisUsername', None) password = config.get('redisPassword', None) socket_timeout = config.get('redisSocketTimeout', None) socket_connect_timeout = config.get('redisSocketConnectTimeout', None) @@ -439,6 +861,7 @@ def _build_sentinel_client(config): # pylint: disable=too-many-locals sentinels, db=database, password=password, + username=username, socket_timeout=socket_timeout, socket_connect_timeout=socket_connect_timeout, socket_keepalive=socket_keepalive, @@ -454,6 +877,86 @@ def _build_sentinel_client(config): # pylint: disable=too-many-locals redis = sentinel.master_for(master_service) return RedisAdapter(redis, prefix=prefix) +async def _build_sentinel_client_async(config): # pylint: disable=too-many-locals + """ + Build a redis client with sentinel replication. + + :param config: Redis configuration properties. + :type config: dict + + :return: A Wrapped redis-sentinel client + :rtype: splitio.storage.adapters.redis.RedisAdapter + """ + sentinels = config.get('redisSentinels') + + if sentinels is None: + raise SentinelConfigurationException('redisSentinels must be specified.') + if not isinstance(sentinels, list): + raise SentinelConfigurationException('Sentinels must be an array of elements in the form of' + ' [(ip, port)].') + if not sentinels: + raise SentinelConfigurationException('It must be at least one sentinel.') + if not all(isinstance(s, tuple) for s in sentinels): + raise SentinelConfigurationException('Sentinels must respect the tuple structure' + '[(ip, port)].') + + master_service = config.get('redisMasterService') + + if master_service is None: + raise SentinelConfigurationException('redisMasterService must be specified.') + + database = config.get('redisDb', 0) + password = config.get('redisPassword', None) + socket_timeout = config.get('redisSocketTimeout', None) + socket_connect_timeout = config.get('redisSocketConnectTimeout', None) + socket_keepalive = config.get('redisSocketKeepalive', None) + socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None) + connection_pool = config.get('redisConnectionPool', None) + encoding = config.get('redisEncoding', 'utf-8') + encoding_errors = config.get('redisEncodingErrors', 'strict') + decode_responses = config.get('redisDecodeResponses', True) + retry_on_timeout = config.get('redisRetryOnTimeout', False) + max_connections = config.get('redisMaxConnections', None) + ssl = config.get('redisSsl', False) + prefix = config.get('redisPrefix') + + sentinel = SentinelAsync( + sentinels, + db=database, + password=password, + encoding=encoding, + encoding_errors=encoding_errors, + decode_responses=decode_responses, + max_connections=max_connections, + connection_pool=connection_pool, + socket_connect_timeout=socket_connect_timeout + ) + + redis = sentinel.master_for( + master_service, + socket_timeout=socket_timeout, + socket_keepalive=socket_keepalive, + socket_keepalive_options=socket_keepalive_options, + encoding_errors=encoding_errors, + retry_on_timeout=retry_on_timeout, + ssl=ssl + ) + return RedisAdapterAsync(redis, prefix=prefix) + +async def build_async(config): + """ + Build a async redis storage according to the configuration received. + + :param config: SDK Configuration parameters with redis properties. + :type config: dict. + + :return: A redis async client + :rtype: splitio.storage.adapters.redis.RedisAdapterAsync + """ + if 'redisSentinels' in config: + return await _build_sentinel_client_async(config) + + return await _build_default_client_async(config) def build(config): """ @@ -467,4 +970,5 @@ def build(config): """ if 'redisSentinels' in config: return _build_sentinel_client(config) + return _build_default_client(config) diff --git a/splitio/storage/inmemmory.py b/splitio/storage/inmemmory.py index ab0b5176..db71f7fd 100644 --- a/splitio/storage/inmemmory.py +++ b/splitio/storage/inmemmory.py @@ -5,91 +5,200 @@ from collections import Counter from splitio.models.segments import Segment -from splitio.storage import SplitStorage, SegmentStorage, ImpressionStorage, EventStorage +from splitio.models.telemetry import HTTPErrors, HTTPLatencies, MethodExceptions, MethodLatencies, LastSynchronization, StreamingEvents, TelemetryConfig, TelemetryCounters, CounterConstants, \ + HTTPErrorsAsync, HTTPLatenciesAsync, MethodExceptionsAsync, MethodLatenciesAsync, LastSynchronizationAsync, StreamingEventsAsync, TelemetryConfigAsync, TelemetryCountersAsync +from splitio.models.events import SdkInternalEvent +from splitio.events.events_metadata import EventsMetadata, SdkEventType +from splitio.models.notification import SdkInternalEventNotification +from splitio.storage import FlagSetsFilter, SplitStorage, SegmentStorage, ImpressionStorage, EventStorage, TelemetryStorage, RuleBasedSegmentsStorage +from splitio.optional.loaders import asyncio MAX_SIZE_BYTES = 5 * 1024 * 1024 - +MAX_TAGS = 10 _LOGGER = logging.getLogger(__name__) +class FlagSets(object): + """InMemory Flagsets storage.""" -class InMemorySplitStorage(SplitStorage): - """InMemory implementation of a split storage.""" - - def __init__(self): + def __init__(self, flag_sets=[]): """Constructor.""" + self.sets_feature_flag_map = {} self._lock = threading.RLock() - self._splits = {} - self._change_number = -1 - self._traffic_types = Counter() + for flag_set in flag_sets: + self.sets_feature_flag_map[flag_set] = set() + + def flag_set_exist(self, flag_set): + """ + Check if a flagset exist in stored flagset + :param flag_set: set name + :type flag_set: str + + :rtype: bool + """ + with self._lock: + return flag_set in self.sets_feature_flag_map.keys() - def get(self, split_name): + def get_flag_set(self, flag_set): """ - Retrieve a split. + fetch feature flags stored in a flag set + :param flag_set: set name + :type flag_set: str - :param split_name: Name of the feature to fetch. - :type split_name: str + :rtype: list(str) + """ + with self._lock: + return self.sets_feature_flag_map.get(flag_set) - :rtype: splitio.models.splits.Split + def _add_flag_set(self, flag_set): + """ + Add new flag set to storage + :param flag_set: set name + :type flag_set: str + """ + with self._lock: + if not self.flag_set_exist(flag_set): + self.sets_feature_flag_map[flag_set] = set() + + def _remove_flag_set(self, flag_set): + """ + Remove existing flag set from storage + :param flag_set: set name + :type flag_set: str + """ + with self._lock: + if self.flag_set_exist(flag_set): + del self.sets_feature_flag_map[flag_set] + + def add_feature_flag_to_flag_set(self, flag_set, feature_flag): + """ + Add a feature flag to existing flag set + :param flag_set: set name + :type flag_set: str + :param feature_flag: feature flag name + :type feature_flag: str + """ + with self._lock: + if self.flag_set_exist(flag_set): + self.sets_feature_flag_map[flag_set].add(feature_flag) + + def remove_feature_flag_to_flag_set(self, flag_set, feature_flag): + """ + Remove a feature flag from existing flag set + :param flag_set: set name + :type flag_set: str + :param feature_flag: feature flag name + :type feature_flag: str + """ + with self._lock: + if self.flag_set_exist(flag_set): + self.sets_feature_flag_map[flag_set].remove(feature_flag) + + def update_flag_set(self, flag_sets, feature_flag_name, should_filter): + if flag_sets is not None: + for flag_set in flag_sets: + if not self.flag_set_exist(flag_set): + if should_filter: + continue + self._add_flag_set(flag_set) + self.add_feature_flag_to_flag_set(flag_set, feature_flag_name) + + def remove_flag_set(self, flag_sets, feature_flag_name, should_filter): + if flag_sets is not None: + for flag_set in flag_sets: + self.remove_feature_flag_to_flag_set(flag_set, feature_flag_name) + if self.flag_set_exist(flag_set) and len(self.get_flag_set(flag_set)) == 0 and not should_filter: + self._remove_flag_set(flag_set) + +class InMemoryRuleBasedSegmentStorage(RuleBasedSegmentsStorage): + """InMemory implementation of a feature flag storage base.""" + + def __init__(self, internal_event_queue): + """Constructor.""" + self._lock = threading.RLock() + self._rule_based_segments = {} + self._change_number = -1 + self._internal_event_queue = internal_event_queue + + def clear(self): + """ + Clear storage """ with self._lock: - return self._splits.get(split_name) + self._rule_based_segments = {} + self._change_number = -1 - def fetch_many(self, split_names): + def get(self, segment_name): """ - Retrieve splits. + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str - :param split_names: Names of the features to fetch. - :type split_name: list(str) + :rtype: splitio.models.rule_based_segments.RuleBasedSegment + """ + with self._lock: + return self._rule_based_segments.get(segment_name) - :return: A dict with split objects parsed from queue. - :rtype: dict(split_name, splitio.models.splits.Split) + def update(self, to_add, to_delete, new_change_number): """ - return {split_name: self.get(split_name) for split_name in split_names} + Update rule based segment. - def put(self, split): + :param to_add: List of rule based segment. to add + :type to_add: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param to_delete: List of rule based segment. to delete + :type to_delete: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param new_change_number: New change number. + :type new_change_number: int """ - Store a split. + [self._put(add_segment) for add_segment in to_add] + [self._remove(delete_segment) for delete_segment in to_delete] + self._set_change_number(new_change_number) + if len(to_add) > 0 or len(to_delete) > 0: + self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.RB_SEGMENTS_UPDATED, + EventsMetadata(SdkEventType.SEGMENTS_UPDATE, {}))) + + def _put(self, rule_based_segment): + """ + Store a rule based segment. - :param split: Split object. - :type split: splitio.models.split.Split + :param rule_based_segment: RuleBasedSegment object. + :type rule_based_segment: splitio.models.rule_based_segments.RuleBasedSegment """ with self._lock: - if split.name in self._splits: - self._decrease_traffic_type_count(self._splits[split.name].traffic_type_name) - self._splits[split.name] = split - self._increase_traffic_type_count(split.traffic_type_name) + self._rule_based_segments[rule_based_segment.name] = rule_based_segment - def remove(self, split_name): + def _remove(self, segment_name): """ - Remove a split from storage. + Remove a rule based segment. - :param split_name: Name of the feature to remove. - :type split_name: str + :param segment_name: Name of the rule based segment to remove. + :type segment_name: str - :return: True if the split was found and removed. False otherwise. + :return: True if the rule based segment was found and removed. False otherwise. :rtype: bool """ with self._lock: - split = self._splits.get(split_name) - if not split: - _LOGGER.warning("Tried to delete nonexistant split %s. Skipping", split_name) + rule_based_segment = self._rule_based_segments.get(segment_name) + if not rule_based_segment: + _LOGGER.warning("Tried to delete nonexistant Rule based segment %s. Skipping", segment_name) return False - self._splits.pop(split_name) - self._decrease_traffic_type_count(split.traffic_type_name) + self._rule_based_segments.pop(segment_name) return True def get_change_number(self): """ - Retrieve latest split change number. + Retrieve latest rule based segment change number. :rtype: int """ with self._lock: return self._change_number - def set_change_number(self, new_change_number): + def _set_change_number(self, new_change_number): """ Set the latest change number. @@ -99,29 +208,248 @@ def set_change_number(self, new_change_number): with self._lock: self._change_number = new_change_number - def get_split_names(self): + def get_segment_names(self): + """ + Retrieve a list of all rule based segments names. + + :return: List of segment names. + :rtype: list(str) + """ + with self._lock: + return list(self._rule_based_segments.keys()) + + def get_large_segment_names(self): """ - Retrieve a list of all split names. + Retrieve a list of all excluded large segments names. - :return: List of split names. + :return: List of segment names. :rtype: list(str) """ + pass + + def contains(self, segment_names): + """ + Return whether the segment exists in storage + + :param segment_names: rule based segment name + :type segment_names: str + + :return: True if the segment exists. False otherwise. + :rtype: bool + """ with self._lock: - return list(self._splits.keys()) + return set(segment_names).issubset(self._rule_based_segments.keys()) + + def fetch_many(self, segment_names): + return {rb_segment_name: self.get(rb_segment_name) for rb_segment_name in segment_names} + +class InMemoryRuleBasedSegmentStorageAsync(RuleBasedSegmentsStorage): + """InMemory implementation of a feature flag storage base.""" + def __init__(self, internal_event_queue): + """Constructor.""" + self._lock = asyncio.Lock() + self._rule_based_segments = {} + self._change_number = -1 + self._internal_event_queue = internal_event_queue + + async def clear(self): + """ + Clear storage + """ + async with self._lock: + self._rule_based_segments = {} + self._change_number = -1 + + async def get(self, segment_name): + """ + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: splitio.models.rule_based_segments.RuleBasedSegment + """ + async with self._lock: + return self._rule_based_segments.get(segment_name) + + async def update(self, to_add, to_delete, new_change_number): + """ + Update rule based segment. + + :param to_add: List of rule based segment. to add + :type to_add: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param to_delete: List of rule based segment. to delete + :type to_delete: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param new_change_number: New change number. + :type new_change_number: int + """ + [await self._put(add_segment) for add_segment in to_add] + [await self._remove(delete_segment) for delete_segment in to_delete] + await self._set_change_number(new_change_number) + if len(to_add) > 0 or len(to_delete) > 0: + await self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.RB_SEGMENTS_UPDATED, + EventsMetadata(SdkEventType.SEGMENTS_UPDATE, {}))) + + async def _put(self, rule_based_segment): + """ + Store a rule based segment. + + :param rule_based_segment: RuleBasedSegment object. + :type rule_based_segment: splitio.models.rule_based_segments.RuleBasedSegment + """ + async with self._lock: + self._rule_based_segments[rule_based_segment.name] = rule_based_segment + + async def _remove(self, segment_name): + """ + Remove a rule based segment. + + :param segment_name: Name of the rule based segment to remove. + :type segment_name: str + + :return: True if the rule based segment was found and removed. False otherwise. + :rtype: bool + """ + async with self._lock: + rule_based_segment = self._rule_based_segments.get(segment_name) + if not rule_based_segment: + _LOGGER.warning("Tried to delete nonexistant Rule based segment %s. Skipping", segment_name) + return False + + self._rule_based_segments.pop(segment_name) + return True + + async def get_change_number(self): + """ + Retrieve latest rule based segment change number. + + :rtype: int + """ + async with self._lock: + return self._change_number + + async def _set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. + :type new_change_number: int + """ + async with self._lock: + self._change_number = new_change_number + + async def get_segment_names(self): + """ + Retrieve a list of all excluded segments names. + + :return: List of segment names. + :rtype: list(str) + """ + async with self._lock: + return list(self._rule_based_segments.keys()) + + async def get_large_segment_names(self): + """ + Retrieve a list of all excluded large segments names. + + :return: List of segment names. + :rtype: list(str) + """ + pass + + async def contains(self, segment_names): + """ + Return whether the segment exists in storage + + :param segment_names: rule based segment name + :type segment_names: str + + :return: True if the segment exists. False otherwise. + :rtype: bool + """ + async with self._lock: + return set(segment_names).issubset(self._rule_based_segments.keys()) + + async def fetch_many(self, segment_names): + return {rb_segment_name: await self.get(rb_segment_name) for rb_segment_name in segment_names} + +class InMemorySplitStorageBase(SplitStorage): + """InMemory implementation of a feature flag storage base.""" + + def get(self, feature_flag_name): + """ + Retrieve a feature flag. + + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str + + :rtype: splitio.models.splits.Split + """ + pass + + def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature flag objects parsed from queue. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) + """ + pass + + def update(self, to_add, to_delete, new_change_number): + """ + Update feature flag storage. + :param to_add: List of feature flags to add + :type to_add: list[splitio.models.splits.Split] + :param to_delete: List of feature flags to delete + :type to_delete: list[str] + :param new_change_number: New change number. + :type new_change_number: int + """ + pass + + def get_change_number(self): + """ + Retrieve latest feature flag change number. + + :rtype: int + """ + pass + + def get_split_names(self): + """ + Retrieve a list of all feature flag names. + + :return: List of feature flag names. + :rtype: list(str) + """ + pass def get_all_splits(self): """ - Return all the splits. + Return all the feature flags. - :return: List of all the splits. + :return: List of all the feature flags. :rtype: list """ - with self._lock: - return list(self._splits.values()) + pass + + def get_splits_count(self): + """ + Return feature flags count. + + :rtype: int + """ + pass def is_valid_traffic_type(self, traffic_type_name): """ - Return whether the traffic type exists in at least one split in cache. + Return whether the traffic type exists in at least one feature flag in cache. :param traffic_type_name: Traffic type to validate. :type traffic_type_name: str @@ -129,28 +457,20 @@ def is_valid_traffic_type(self, traffic_type_name): :return: True if the traffic type is valid. False otherwise. :rtype: bool """ - with self._lock: - return traffic_type_name in self._traffic_types + pass - def kill_locally(self, split_name, default_treatment, change_number): + def kill_locally(self, feature_flag_name, default_treatment, change_number): """ - Local kill for split + Local kill for feature flag - :param split_name: name of the split to perform kill - :type split_name: str + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str :param default_treatment: name of the default treatment to return :type default_treatment: str :param change_number: change_number :type change_number: int """ - with self._lock: - if self.get_change_number() > change_number: - return - split = self._splits.get(split_name) - if not split: - return - split.local_kill(default_treatment, change_number) - self.put(split) + pass def _increase_traffic_type_count(self, traffic_type_name): """ @@ -171,118 +491,797 @@ def _decrease_traffic_type_count(self, traffic_type_name): self._traffic_types.subtract([traffic_type_name]) self._traffic_types += Counter() +class InMemorySplitStorage(InMemorySplitStorageBase): + """InMemory implementation of a feature flag storage.""" -class InMemorySegmentStorage(SegmentStorage): - """In-memory implementation of a segment storage.""" - - def __init__(self): + def __init__(self, internal_event_queue, flag_sets=[]): """Constructor.""" - self._segments = {} - self._change_numbers = {} self._lock = threading.RLock() + self._feature_flags = {} + self._change_number = -1 + self._traffic_types = Counter() + self.flag_set = FlagSets(flag_sets) + self.flag_set_filter = FlagSetsFilter(flag_sets) + self._internal_event_queue = internal_event_queue - def get(self, segment_name): + def clear(self): """ - Retrieve a segment. + Clear storage + """ + with self._lock: + self._feature_flags = {} + self._change_number = -1 + self._traffic_types = Counter() + self.flag_set = FlagSets(self.flag_set_filter.flag_sets) + + def get(self, feature_flag_name): + """ + Retrieve a feature flag. - :param segment_name: Name of the segment to fetch. - :type segment_name: str + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str - :rtype: str + :rtype: splitio.models.splits.Split """ with self._lock: - fetched = self._segments.get(segment_name) - if fetched is None: - _LOGGER.warning( - "Tried to retrieve nonexistant segment %s. Skipping", - segment_name - ) - return fetched + return self._feature_flags.get(feature_flag_name) - def put(self, segment): + def fetch_many(self, feature_flag_names): """ - Store a segment. + Retrieve feature flags. - :param segment: Segment to store. - :type segment: splitio.models.segment.Segment + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_names: list(str) + + :return: A dict with feature flag objects parsed from queue. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) """ - with self._lock: - self._segments[segment.name] = segment + return {feature_flag_name: self.get(feature_flag_name) for feature_flag_name in feature_flag_names} - def update(self, segment_name, to_add, to_remove, change_number=None): + def update(self, to_add, to_delete, new_change_number): + """ + Update feature flag storage. + :param to_add: List of feature flags to add + :type to_add: list[splitio.models.splits.Split] + :param to_delete: List of feature flags to delete + :type to_delete: list[str] + :param new_change_number: New change number. + :type new_change_number: int """ - Update a split. Create it if it doesn't exist. + [self._put(add_feature_flag) for add_feature_flag in to_add] + [self._remove(delete_feature_flag) for delete_feature_flag in to_delete] + self._set_change_number(new_change_number) + to_notify = [] + [to_notify.append(feature.name) for feature in to_add] + to_notify.extend(to_delete) + if len(to_notify) > 0: + self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.FLAGS_UPDATED, + EventsMetadata(SdkEventType.FLAG_UPDATE, set(to_notify)))) + + def _put(self, feature_flag): + """ + Store a feature flag. - :param segment_name: Name of the segment to update. - :type segment_name: str - :param to_add: Set of members to add to the segment. - :type to_add: set - :param to_remove: List of members to remove from the segment. - :type to_remove: Set + :param feature_flag: Split object. + :type feature_flag: splitio.models.split.Split """ with self._lock: - if segment_name not in self._segments: - self._segments[segment_name] = Segment(segment_name, to_add, change_number) - return + if feature_flag.name in self._feature_flags: + self._remove_from_flag_sets(self._feature_flags[feature_flag.name]) + self._decrease_traffic_type_count(self._feature_flags[feature_flag.name].traffic_type_name) + self._feature_flags[feature_flag.name] = feature_flag + self._increase_traffic_type_count(feature_flag.traffic_type_name) + self.flag_set.update_flag_set(feature_flag.sets, feature_flag.name, self.flag_set_filter.should_filter) + + def _remove(self, feature_flag_name): + """ + Remove a feature flag from storage. - self._segments[segment_name].update(to_add, to_remove) - if change_number is not None: - self._segments[segment_name].change_number = change_number + :param feature_flag_name: Name of the feature to remove. + :type feature_flag_name: str - def get_change_number(self, segment_name): + :return: True if the feature_flag was found and removed. False otherwise. + :rtype: bool """ - Retrieve latest change number for a segment. + with self._lock: + feature_flag = self._feature_flags.get(feature_flag_name) + if not feature_flag: + _LOGGER.warning("Tried to delete nonexistant feature flag %s. Skipping", feature_flag_name) + return False - :param segment_name: Name of the segment. - :type segment_name: str + self._feature_flags.pop(feature_flag_name) + self._decrease_traffic_type_count(feature_flag.traffic_type_name) + self._remove_from_flag_sets(feature_flag) + return True + + def _remove_from_flag_sets(self, feature_flag): + """ + Remove flag sets associated to a feature flag + :param feature_flag: feature flag object + :type feature_flag: splitio.models.splits.Split + """ + self.flag_set.remove_flag_set(feature_flag.sets, feature_flag.name, self.flag_set_filter.should_filter) + + def get_feature_flags_by_sets(self, sets): + """ + Get list of feature flag names associated to a set, if it does not exist will return empty list + :param set: flag set + :type set: str + :return: list of feature flag names + :rtype: list + """ + with self._lock: + sets_to_fetch = [] + for flag_set in sets: + if not self.flag_set.flag_set_exist(flag_set): + _LOGGER.warning("Flag set %s is not part of the configured flag set list, ignoring it." % (flag_set)) + continue + sets_to_fetch.append(flag_set) + + to_return = set() + [to_return.update(self.flag_set.get_flag_set(flag_set)) for flag_set in sets_to_fetch] + return list(to_return) + + def get_change_number(self): + """ + Retrieve latest feature flag change number. :rtype: int """ with self._lock: - if segment_name not in self._segments: - return None - return self._segments[segment_name].change_number + return self._change_number - def set_change_number(self, segment_name, new_change_number): + def _set_change_number(self, new_change_number): """ Set the latest change number. - :param segment_name: Name of the segment. - :type segment_name: str :param new_change_number: New change number. :type new_change_number: int """ with self._lock: - if segment_name not in self._segments: - return - self._segments[segment_name].change_number = new_change_number + self._change_number = new_change_number - def segment_contains(self, segment_name, key): + def get_split_names(self): """ - Check whether a specific key belongs to a segment in storage. - - :param segment_name: Name of the segment to search in. - :type segment_name: str - :param key: Key to search for. - :type key: str + Retrieve a list of all feature flag names. - :return: True if the segment contains the key. False otherwise. - :rtype: bool + :return: List of feature flag names. + :rtype: list(str) """ with self._lock: - if segment_name not in self._segments: - _LOGGER.warning( - "Tried to query members for nonexistant segment %s. Returning False", - segment_name + return list(self._feature_flags.keys()) + + def get_all_splits(self): + """ + Return all the feature flags. + + :return: List of all the feature flags. + :rtype: list + """ + with self._lock: + return list(self._feature_flags.values()) + + def get_splits_count(self): + """ + Return feature flags count. + + :rtype: int + """ + with self._lock: + return len(self._feature_flags) + + def is_valid_traffic_type(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + with self._lock: + return traffic_type_name in self._traffic_types + + def kill_locally(self, feature_flag_name, default_treatment, change_number): + """ + Local kill for feature flag + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + with self._lock: + if self.get_change_number() > change_number: + return + feature_flag = self._feature_flags.get(feature_flag_name) + if not feature_flag: + return + feature_flag.local_kill(default_treatment, change_number) + self._put(feature_flag) + self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.FLAG_KILLED_NOTIFICATION, + EventsMetadata(SdkEventType.FLAG_UPDATE, {feature_flag_name}))) + + def is_flag_set_exist(self, flag_set): + """ + Return whether a flag set exists in at least one feature flag in cache. + :param flag_set: Flag set to validate. + :type flag_set: str + + :return: True if the flag_set exist. False otherwise. + :rtype: bool + """ + return self.flag_set.flag_set_exist(flag_set) + +class InMemorySplitStorageAsync(InMemorySplitStorageBase): + """InMemory implementation of a feature flag async storage.""" + + def __init__(self, internal_event_queue, flag_sets=[]): + """Constructor.""" + self._lock = asyncio.Lock() + self._feature_flags = {} + self._change_number = -1 + self._traffic_types = Counter() + self.flag_set = FlagSets(flag_sets) + self.flag_set_filter = FlagSetsFilter(flag_sets) + self._internal_event_queue = internal_event_queue + + async def clear(self): + """ + Clear storage + """ + async with self._lock: + self._feature_flags = {} + self._change_number = -1 + self._traffic_types = Counter() + self.flag_set = FlagSets(self.flag_set_filter.flag_sets) + + async def get(self, feature_flag_name): + """ + Retrieve a feature flag. + + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str + + :rtype: splitio.models.splits.Split + """ + async with self._lock: + return self._feature_flags.get(feature_flag_name) + + async def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature flag objects parsed from queue. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) + """ + return {feature_flag_name: await self.get(feature_flag_name) for feature_flag_name in feature_flag_names} + + async def update(self, to_add, to_delete, new_change_number): + """ + Update feature flag storage. + :param to_add: List of feature flags to add + :type to_add: list[splitio.models.splits.Split] + :param to_delete: List of feature flags to delete + :type to_delete: list[str] + :param new_change_number: New change number. + :type new_change_number: int + """ + [await self._put(add_feature_flag) for add_feature_flag in to_add] + [await self._remove(delete_feature_flag) for delete_feature_flag in to_delete] + await self._set_change_number(new_change_number) + to_notify = [] + [to_notify.append(feature.name) for feature in to_add] + to_notify.extend(to_delete) + if len(to_notify) > 0: + await self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.FLAGS_UPDATED, + EventsMetadata(SdkEventType.FLAG_UPDATE, set(to_notify)))) + + async def _put(self, feature_flag): + """ + Store a feature flag. + + :param feature flag: Split object. + :type feature flag: splitio.models.split.Split + """ + async with self._lock: + if feature_flag.name in self._feature_flags: + await self._remove_from_flag_sets(self._feature_flags[feature_flag.name]) + self._decrease_traffic_type_count(self._feature_flags[feature_flag.name].traffic_type_name) + self._feature_flags[feature_flag.name] = feature_flag + self._increase_traffic_type_count(feature_flag.traffic_type_name) + self.flag_set.update_flag_set(feature_flag.sets, feature_flag.name, self.flag_set_filter.should_filter) + + async def _remove(self, feature_flag_name): + """ + Remove a feature flag from storage. + + :param feature_flag_name: Name of the feature to remove. + :type feature_flag_name: str + + :return: True if the feature flag was found and removed. False otherwise. + :rtype: bool + """ + async with self._lock: + feature_flag = self._feature_flags.get(feature_flag_name) + if not feature_flag: + _LOGGER.warning("Tried to delete nonexistant feature flag %s. Skipping", feature_flag_name) + return False + + self._feature_flags.pop(feature_flag_name) + self._decrease_traffic_type_count(feature_flag.traffic_type_name) + await self._remove_from_flag_sets(feature_flag) + return True + + async def _remove_from_flag_sets(self, feature_flag): + """ + Remove flag sets associated to a feature flag + :param feature_flag: feature flag object + :type feature_flag: splitio.models.splits.Split + """ + self.flag_set.remove_flag_set(feature_flag.sets, feature_flag.name, self.flag_set_filter.should_filter) + + async def get_feature_flags_by_sets(self, sets): + """ + Get list of feature flag names associated to a set, if it does not exist will return empty list + :param set: flag set + :type set: str + :return: list of feature flag names + :rtype: list + """ + async with self._lock: + sets_to_fetch = [] + for flag_set in sets: + if not self.flag_set.flag_set_exist(flag_set): + _LOGGER.warning("Flag set %s is not part of the configured flag set list, ignoring it." % (flag_set)) + continue + sets_to_fetch.append(flag_set) + + to_return = set() + [to_return.update(self.flag_set.get_flag_set(flag_set)) for flag_set in sets_to_fetch] + return list(to_return) + + async def get_change_number(self): + """ + Retrieve latest feature flag change number. + + :rtype: int + """ + async with self._lock: + return self._change_number + + async def _set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. + :type new_change_number: int + """ + async with self._lock: + self._change_number = new_change_number + + async def get_split_names(self): + """ + Retrieve a list of all feature flag names. + + :return: List of feature flag names. + :rtype: list(str) + """ + async with self._lock: + return list(self._feature_flags.keys()) + + async def get_all_splits(self): + """ + Return all the feature flags. + + :return: List of all the feature flags. + :rtype: list + """ + async with self._lock: + return list(self._feature_flags.values()) + + async def get_splits_count(self): + """ + Return feature flags count. + + :rtype: int + """ + async with self._lock: + return len(self._feature_flags) + + async def is_valid_traffic_type(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + async with self._lock: + return traffic_type_name in self._traffic_types + + async def kill_locally(self, feature_flag_name, default_treatment, change_number): + """ + Local kill for feature flag + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + if await self.get_change_number() > change_number: + return + async with self._lock: + feature_flag = self._feature_flags.get(feature_flag_name) + if not feature_flag: + return + feature_flag.local_kill(default_treatment, change_number) + await self._put(feature_flag) + await self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.FLAG_KILLED_NOTIFICATION, + EventsMetadata(SdkEventType.FLAG_UPDATE, {feature_flag_name}))) + + + async def get_segment_names(self): + """ + Return a set of all segments referenced by feature flags in storage. + + :return: Set of all segment names. + :rtype: set(string) + """ + return set([name for spl in await self.get_all_splits() for name in spl.get_segment_names()]) + + async def is_flag_set_exist(self, flag_set): + """ + Return whether a flag set exists in at least one feature flag in cache. + :param flag_set: Flag set to validate. + :type flag_set: str + :return: True if the flag_set exist. False otherwise. + :rtype: bool + """ + return self.flag_set.flag_set_exist(flag_set) + +class InMemorySegmentStorage(SegmentStorage): + """In-memory implementation of a segment storage.""" + + def __init__(self, internal_event_queue): + """Constructor.""" + self._segments = {} + self._change_numbers = {} + self._lock = threading.RLock() + self._internal_event_queue = internal_event_queue + + def get(self, segment_name): + """ + Retrieve a segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str + """ + with self._lock: + fetched = self._segments.get(segment_name) + if fetched is None: + _LOGGER.debug( + "Tried to retrieve nonexistant segment %s. Skipping", + segment_name + ) + return fetched + + def put(self, segment): + """ + Store a segment. + + :param segment: Segment to store. + :type segment: splitio.models.segment.Segment + """ + with self._lock: + self._segments[segment.name] = segment + + self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.SEGMENTS_UPDATED, + EventsMetadata(SdkEventType.SEGMENTS_UPDATE, {}))) + + def update(self, segment_name, to_add, to_remove, change_number=None): + """ + Update a segment. Create it if it doesn't exist. + + :param segment_name: Name of the segment to update. + :type segment_name: str + :param to_add: Set of members to add to the segment. + :type to_add: set + :param to_remove: List of members to remove from the segment. + :type to_remove: Set + """ + with self._lock: + if segment_name not in self._segments: + self._segments[segment_name] = Segment(segment_name, to_add, change_number) + return + + self._segments[segment_name].update(to_add, to_remove) + if change_number is not None: + self._segments[segment_name].change_number = change_number + + if len(to_add) > 0 or len(to_remove) >0: + self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.SEGMENTS_UPDATED, + EventsMetadata(SdkEventType.SEGMENTS_UPDATE, {}))) + + def get_change_number(self, segment_name): + """ + Retrieve latest change number for a segment. + + :param segment_name: Name of the segment. + :type segment_name: str + + :rtype: int + """ + with self._lock: + if segment_name not in self._segments: + return None + + return self._segments[segment_name].change_number + + def set_change_number(self, segment_name, new_change_number): + """ + Set the latest change number. + + :param segment_name: Name of the segment. + :type segment_name: str + :param new_change_number: New change number. + :type new_change_number: int + """ + with self._lock: + if segment_name not in self._segments: + return + self._segments[segment_name].change_number = new_change_number + + def segment_contains(self, segment_name, key): + """ + Check whether a specific key belongs to a segment in storage. + + :param segment_name: Name of the segment to search in. + :type segment_name: str + :param key: Key to search for. + :type key: str + + :return: True if the segment contains the key. False otherwise. + :rtype: bool + """ + with self._lock: + if segment_name not in self._segments: + _LOGGER.warning( + "Tried to query members for nonexistant segment %s. Returning False", + segment_name + ) + return False + + return self._segments[segment_name].contains(key) + + def get_segments_count(self): + """ + Retrieve segments count. + + :rtype: int + """ + with self._lock: + return len(self._segments) + + def get_segments_keys_count(self): + """ + Retrieve segments keys count. + + :rtype: int + """ + total_count = 0 + with self._lock: + for segment in self._segments: + total_count += len(self._segments[segment]._keys) + return total_count + + +class InMemorySegmentStorageAsync(SegmentStorage): + """In-memory implementation of a segment async storage.""" + + def __init__(self, internal_event_queue): + """Constructor.""" + self._segments = {} + self._change_numbers = {} + self._lock = asyncio.Lock() + self._internal_event_queue = internal_event_queue + + async def get(self, segment_name): + """ + Retrieve a segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str + """ + async with self._lock: + fetched = self._segments.get(segment_name) + if fetched is None: + _LOGGER.debug( + "Tried to retrieve nonexistant segment %s. Skipping", + segment_name + ) + return fetched + + async def put(self, segment): + """ + Store a segment. + + :param segment: Segment to store. + :type segment: splitio.models.segment.Segment + """ + async with self._lock: + self._segments[segment.name] = segment + await self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.SEGMENTS_UPDATED, + EventsMetadata(SdkEventType.SEGMENTS_UPDATE, {}))) + + + async def update(self, segment_name, to_add, to_remove, change_number=None): + """ + Update a segment. Create it if it doesn't exist. + + :param segment_name: Name of the segment to update. + :type segment_name: str + :param to_add: Set of members to add to the segment. + :type to_add: set + :param to_remove: List of members to remove from the segment. + :type to_remove: Set + """ + async with self._lock: + if segment_name not in self._segments: + self._segments[segment_name] = Segment(segment_name, to_add, change_number) + return + + self._segments[segment_name].update(to_add, to_remove) + if change_number is not None: + self._segments[segment_name].change_number = change_number + if len(to_add) > 0 or len(to_remove) >0: + await self._internal_event_queue.put( + SdkInternalEventNotification( + SdkInternalEvent.SEGMENTS_UPDATED, + EventsMetadata(SdkEventType.SEGMENTS_UPDATE, {}))) + + + async def get_change_number(self, segment_name): + """ + Retrieve latest change number for a segment. + + :param segment_name: Name of the segment. + :type segment_name: str + + :rtype: int + """ + async with self._lock: + if segment_name not in self._segments: + return None + + return self._segments[segment_name].change_number + + async def set_change_number(self, segment_name, new_change_number): + """ + Set the latest change number. + + :param segment_name: Name of the segment. + :type segment_name: str + :param new_change_number: New change number. + :type new_change_number: int + """ + async with self._lock: + if segment_name not in self._segments: + return + self._segments[segment_name].change_number = new_change_number + + async def segment_contains(self, segment_name, key): + """ + Check whether a specific key belongs to a segment in storage. + + :param segment_name: Name of the segment to search in. + :type segment_name: str + :param key: Key to search for. + :type key: str + + :return: True if the segment contains the key. False otherwise. + :rtype: bool + """ + async with self._lock: + if segment_name not in self._segments: + _LOGGER.warning( + "Tried to query members for nonexistant segment %s. Returning False", + segment_name ) return False + return self._segments[segment_name].contains(key) + async def get_segments_count(self): + """ + Retrieve segments count. + + :rtype: int + """ + async with self._lock: + return len(self._segments) + + async def get_segments_keys_count(self): + """ + Retrieve segments keys count. + + :rtype: int + """ + total_count = 0 + async with self._lock: + for segment in self._segments: + total_count += len(self._segments[segment]._keys) + return total_count + + +class InMemoryImpressionStorageBase(ImpressionStorage): + """In memory implementation of an impressions base storage.""" + + def set_queue_full_hook(self, hook): + """ + Set a hook to be called when the queue is full. + + :param h: Hook to be called when the queue is full + """ + if callable(hook): + self._queue_full_hook = hook + + def put(self, impressions): + """ + Put one or more impressions in storage. + + :param impressions: List of one or more impressions to store. + :type impressions: list + """ + pass + + def pop_many(self, count): + """ + Pop the oldest N impressions from storage. + + :param count: Number of impressions to pop. + :type count: int + """ + pass + + def clear(self): + """ + Clear data. + """ + pass -class InMemoryImpressionStorage(ImpressionStorage): +class InMemoryImpressionStorage(InMemoryImpressionStorageBase): """In memory implementation of an impressions storage.""" - def __init__(self, queue_size): + def __init__(self, queue_size, telemetry_runtime_producer): """ Construct an instance. @@ -292,15 +1291,7 @@ def __init__(self, queue_size): self._impressions = queue.Queue(maxsize=queue_size) self._lock = threading.Lock() self._queue_full_hook = None - - def set_queue_full_hook(self, hook): - """ - Set a hook to be called when the queue is full. - - :param h: Hook to be called when the queue is full - """ - if callable(hook): - self._queue_full_hook = hook + self._telemetry_runtime_producer = telemetry_runtime_producer def put(self, impressions): """ @@ -309,12 +1300,18 @@ def put(self, impressions): :param impressions: List of one or more impressions to store. :type impressions: list """ + impressions_stored = 0 try: with self._lock: for impression in impressions: self._impressions.put(impression, False) + impressions_stored += 1 + self._telemetry_runtime_producer.record_impression_stats(CounterConstants.IMPRESSIONS_QUEUED, len(impressions)) return True + except queue.Full: + self._telemetry_runtime_producer.record_impression_stats(CounterConstants.IMPRESSIONS_DROPPED, len(impressions) - impressions_stored) + self._telemetry_runtime_producer.record_impression_stats(CounterConstants.IMPRESSIONS_QUEUED, impressions_stored) if self._queue_full_hook is not None and callable(self._queue_full_hook): self._queue_full_hook() _LOGGER.warning( @@ -345,25 +1342,77 @@ def clear(self): self._impressions = queue.Queue(maxsize=self._queue_size) -class InMemoryEventStorage(EventStorage): - """ - In memory storage for events. - - Supports adding and popping events. - """ +class InMemoryImpressionStorageAsync(InMemoryImpressionStorageBase): + """In memory implementation of an impressions async storage.""" - def __init__(self, eventsQueueSize): + def __init__(self, queue_size, telemetry_runtime_producer): """ Construct an instance. :param eventsQueueSize: How many events to queue before forcing a submission """ - self._queue_size = eventsQueueSize - self._lock = threading.Lock() - self._events = queue.Queue(maxsize=eventsQueueSize) + self._queue_size = queue_size + self._impressions = asyncio.Queue(maxsize=queue_size) + self._lock = asyncio.Lock() self._queue_full_hook = None - self._size = 0 + self._telemetry_runtime_producer = telemetry_runtime_producer + + async def put(self, impressions): + """ + Put one or more impressions in storage. + + :param impressions: List of one or more impressions to store. + :type impressions: list + """ + impressions_stored = 0 + try: + async with self._lock: + for impression in impressions: + if self._impressions.qsize() == self._queue_size: + raise asyncio.QueueFull + await self._impressions.put(impression) + impressions_stored += 1 + await self._telemetry_runtime_producer.record_impression_stats(CounterConstants.IMPRESSIONS_QUEUED, len(impressions)) + return True + + except asyncio.QueueFull: + await self._telemetry_runtime_producer.record_impression_stats(CounterConstants.IMPRESSIONS_DROPPED, len(impressions) - impressions_stored) + await self._telemetry_runtime_producer.record_impression_stats(CounterConstants.IMPRESSIONS_QUEUED, impressions_stored) + if self._queue_full_hook is not None and callable(self._queue_full_hook): + await self._queue_full_hook() + _LOGGER.warning( + 'Impression queue is full, failing to add more impressions. \n' + 'Consider increasing parameter `impressionsQueueSize` in configuration' + ) + return False + + async def pop_many(self, count): + """ + Pop the oldest N impressions from storage. + + :param count: Number of impressions to pop. + :type count: int + """ + impressions = [] + async with self._lock: + while not self._impressions.empty() and count > 0: + impressions.append(await self._impressions.get()) + count -= 1 + return impressions + + async def clear(self): + """ + Clear data. + """ + async with self._lock: + self._impressions = asyncio.Queue(maxsize=self._queue_size) + +class InMemoryEventStorageBase(EventStorage): + """ + In memory storage base class for events. + Supports adding and popping events. + """ def set_queue_full_hook(self, hook): """ Set a hook to be called when the queue is full. @@ -379,6 +1428,50 @@ def put(self, events): :param event: Event to be added in the storage """ + pass + + def pop_many(self, count): + """ + Pop multiple items from the storage. + + :param count: number of items to be retrieved and removed from the queue. + """ + pass + + def clear(self): + """ + Clear data. + """ + pass + + +class InMemoryEventStorage(InMemoryEventStorageBase): + """ + In memory storage for events. + + Supports adding and popping events. + """ + + def __init__(self, eventsQueueSize, telemetry_runtime_producer): + """ + Construct an instance. + + :param eventsQueueSize: How many events to queue before forcing a submission + """ + self._queue_size = eventsQueueSize + self._lock = threading.Lock() + self._events = queue.Queue(maxsize=eventsQueueSize) + self._queue_full_hook = None + self._size = 0 + self._telemetry_runtime_producer = telemetry_runtime_producer + + def put(self, events): + """ + Add an event to storage. + + :param event: Event to be added in the storage + """ + events_stored = 0 try: with self._lock: for event in events: @@ -389,8 +1482,13 @@ def put(self, events): return False self._events.put(event.event, False) + events_stored += 1 + self._telemetry_runtime_producer.record_event_stats(CounterConstants.EVENTS_QUEUED, len(events)) return True + except queue.Full: + self._telemetry_runtime_producer.record_event_stats(CounterConstants.EVENTS_DROPPED, len(events) - events_stored) + self._telemetry_runtime_producer.record_event_stats(CounterConstants.EVENTS_QUEUED, events_stored) if self._queue_full_hook is not None and callable(self._queue_full_hook): self._queue_full_hook() _LOGGER.warning( @@ -419,3 +1517,729 @@ def clear(self): """ with self._lock: self._events = queue.Queue(maxsize=self._queue_size) + + +class InMemoryEventStorageAsync(InMemoryEventStorageBase): + """ + In memory async storage for events. + Supports adding and popping events. + """ + def __init__(self, eventsQueueSize, telemetry_runtime_producer): + """ + Construct an instance. + + :param eventsQueueSize: How many events to queue before forcing a submission + """ + self._queue_size = eventsQueueSize + self._lock = asyncio.Lock() + self._events = asyncio.Queue(maxsize=eventsQueueSize) + self._queue_full_hook = None + self._size = 0 + self._telemetry_runtime_producer = telemetry_runtime_producer + + async def put(self, events): + """ + Add an event to storage. + + :param event: Event to be added in the storage + """ + events_stored = 0 + try: + async with self._lock: + for event in events: + if self._events.qsize() == self._queue_size: + raise asyncio.QueueFull + + self._size += event.size + if self._size >= MAX_SIZE_BYTES: + await self._queue_full_hook() + return False + + await self._events.put(event.event) + events_stored += 1 + await self._telemetry_runtime_producer.record_event_stats(CounterConstants.EVENTS_QUEUED, len(events)) + return True + + except asyncio.QueueFull: + await self._telemetry_runtime_producer.record_event_stats(CounterConstants.EVENTS_DROPPED, len(events) - events_stored) + await self._telemetry_runtime_producer.record_event_stats(CounterConstants.EVENTS_QUEUED, events_stored) + if self._queue_full_hook is not None and callable(self._queue_full_hook): + await self._queue_full_hook() + _LOGGER.warning( + 'Events queue is full, failing to add more events. \n' + 'Consider increasing parameter `eventsQueueSize` in configuration' + ) + return False + + async def pop_many(self, count): + """ + Pop multiple items from the storage. + + :param count: number of items to be retrieved and removed from the queue. + """ + events = [] + async with self._lock: + while not self._events.empty() and count > 0: + events.append(await self._events.get()) + count -= 1 + self._size = 0 + return events + + async def clear(self): + """ + Clear data. + """ + async with self._lock: + self._events = asyncio.Queue(maxsize=self._queue_size) + + +class InMemoryTelemetryStorageBase(TelemetryStorage): + """In-memory telemetry storage base.""" + + def _reset_tags(self): + self._tags = [] + + def _reset_config_tags(self): + self._config_tags = [] + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """Record configurations.""" + pass + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + pass + + def record_ready_time(self, ready_time): + """Record ready time.""" + pass + + def add_tag(self, tag): + """Record tag string.""" + pass + + def add_config_tag(self, tag): + """Record tag string.""" + pass + + def record_bur_time_out(self): + """Record block until ready timeout.""" + pass + + def record_not_ready_usage(self): + """record non-ready usage.""" + pass + + def record_latency(self, method, latency): + """Record method latency time.""" + pass + + def record_exception(self, method): + """Record method exception.""" + pass + + def record_impression_stats(self, data_type, count): + """Record impressions stats.""" + pass + + def record_event_stats(self, data_type, count): + """Record events stats.""" + pass + + def record_successful_sync(self, resource, time): + """Record successful sync.""" + pass + + def record_sync_error(self, resource, status): + """Record sync http error.""" + pass + + def record_sync_latency(self, resource, latency): + """Record latency time.""" + pass + + def record_auth_rejections(self): + """Record auth rejection.""" + pass + + def record_token_refreshes(self): + """Record sse token refresh.""" + pass + + def record_streaming_event(self, streaming_event): + """Record incoming streaming event.""" + pass + + def record_session_length(self, session): + """Record session length.""" + pass + + def record_update_from_sse(self, event): + """Record update from sse.""" + pass + + def get_bur_time_outs(self): + """Get block until ready timeout.""" + pass + + def get_non_ready_usage(self): + """Get non-ready usage.""" + pass + + def get_config_stats(self): + """Get all config info.""" + pass + + def pop_exceptions(self): + """Get and reset method exceptions.""" + pass + + def pop_tags(self): + """Get and reset tags.""" + pass + + def pop_config_tags(self): + """Get and reset tags.""" + pass + + def pop_latencies(self): + """Get and reset eval latencies.""" + pass + + def get_impressions_stats(self, type): + """Get impressions stats""" + pass + + def get_events_stats(self, type): + """Get events stats""" + pass + + def get_last_synchronization(self): + """Get last sync""" + pass + + def pop_http_errors(self): + """Get and reset http errors.""" + pass + + def pop_http_latencies(self): + """Get and reset http latencies.""" + pass + + def pop_auth_rejections(self): + """Get and reset auth rejections.""" + pass + + def pop_token_refreshes(self): + """Get and reset token refreshes.""" + pass + + def pop_streaming_events(self): + """Get and reset streaming events""" + pass + + def get_session_length(self): + """Get session length""" + pass + + def pop_update_from_sse(self, event): + """Get and reset update from sse.""" + pass + +class InMemoryTelemetryStorage(InMemoryTelemetryStorageBase): + """In-memory telemetry storage.""" + + def __init__(self): + """Constructor""" + self._lock = threading.RLock() + self._method_exceptions = MethodExceptions() + self._last_synchronization = LastSynchronization() + self._counters = TelemetryCounters() + self._http_sync_errors = HTTPErrors() + self._method_latencies = MethodLatencies() + self._http_latencies = HTTPLatencies() + self._streaming_events = StreamingEvents() + self._tel_config = TelemetryConfig() + with self._lock: + self._reset_tags() + self._reset_config_tags() + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """Record configurations.""" + self._tel_config.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def record_ready_time(self, ready_time): + """Record ready time.""" + self._tel_config.record_ready_time(ready_time) + + def add_tag(self, tag): + """Record tag string.""" + with self._lock: + if len(self._tags) < MAX_TAGS: + self._tags.append(tag) + + def add_config_tag(self, tag): + """Record tag string.""" + with self._lock: + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + def record_bur_time_out(self): + """Record block until ready timeout.""" + self._tel_config.record_bur_time_out() + + def record_not_ready_usage(self): + """record non-ready usage.""" + self._tel_config.record_not_ready_usage() + + def record_latency(self, method, latency): + """Record method latency time.""" + self._method_latencies.add_latency(method,latency) + + def record_exception(self, method): + """Record method exception.""" + self._method_exceptions.add_exception(method) + + def record_impression_stats(self, data_type, count): + """Record impressions stats.""" + self._counters.record_impressions_value(data_type, count) + + def record_event_stats(self, data_type, count): + """Record events stats.""" + self._counters.record_events_value(data_type, count) + + def record_successful_sync(self, resource, time): + """Record successful sync.""" + self._last_synchronization.add_latency(resource, time) + + def record_sync_error(self, resource, status): + """Record sync http error.""" + self._http_sync_errors.add_error(resource, status) + + def record_sync_latency(self, resource, latency): + """Record latency time.""" + self._http_latencies.add_latency(resource, latency) + + def record_auth_rejections(self): + """Record auth rejection.""" + self._counters.record_auth_rejections() + + def record_token_refreshes(self): + """Record sse token refresh.""" + self._counters.record_token_refreshes() + + def record_streaming_event(self, streaming_event): + """Record incoming streaming event.""" + self._streaming_events.record_streaming_event(streaming_event) + + def record_session_length(self, session): + """Record session length.""" + self._counters.record_session_length(session) + + def record_update_from_sse(self, event): + """Record update from sse.""" + self._counters.record_update_from_sse(event) + + def get_bur_time_outs(self): + """Get block until ready timeout.""" + return self._tel_config.get_bur_time_outs() + + def get_non_ready_usage(self): + """Get non-ready usage.""" + return self._tel_config.get_non_ready_usage() + + def get_config_stats(self): + """Get all config info.""" + return self._tel_config.get_stats() + + def pop_exceptions(self): + """Get and reset method exceptions.""" + return self._method_exceptions.pop_all() + + def pop_tags(self): + """Get and reset tags.""" + with self._lock: + tags = self._tags + self._reset_tags() + return tags + + def pop_config_tags(self): + """Get and reset tags.""" + with self._lock: + tags = self._config_tags + self._reset_config_tags() + return tags + + def pop_latencies(self): + """Get and reset eval latencies.""" + return self._method_latencies.pop_all() + + def get_impressions_stats(self, type): + """Get impressions stats""" + return self._counters.get_counter_stats(type) + + def get_events_stats(self, type): + """Get events stats""" + return self._counters.get_counter_stats(type) + + def get_last_synchronization(self): + """Get last sync""" + return self._last_synchronization.get_all() + + def pop_http_errors(self): + """Get and reset http errors.""" + return self._http_sync_errors.pop_all() + + def pop_http_latencies(self): + """Get and reset http latencies.""" + return self._http_latencies.pop_all() + + def pop_auth_rejections(self): + """Get and reset auth rejections.""" + return self._counters.pop_auth_rejections() + + def pop_token_refreshes(self): + """Get and reset token refreshes.""" + return self._counters.pop_token_refreshes() + + def pop_streaming_events(self): + return self._streaming_events.pop_streaming_events() + + def get_session_length(self): + """Get session length""" + return self._counters.get_session_length() + + def pop_update_from_sse(self, event): + """Get and reset update from sse.""" + return self._counters.pop_update_from_sse(event) + +class InMemoryTelemetryStorageAsync(InMemoryTelemetryStorageBase): + """In-memory telemetry async storage.""" + + @classmethod + async def create(cls): + """Constructor""" + self = cls() + self._lock = asyncio.Lock() + self._method_exceptions = await MethodExceptionsAsync.create() + self._last_synchronization = await LastSynchronizationAsync.create() + self._counters = await TelemetryCountersAsync.create() + self._http_sync_errors = await HTTPErrorsAsync.create() + self._method_latencies = await MethodLatenciesAsync.create() + self._http_latencies = await HTTPLatenciesAsync.create() + self._streaming_events = await StreamingEventsAsync.create() + self._tel_config = await TelemetryConfigAsync.create() + async with self._lock: + self._reset_tags() + self._reset_config_tags() + return self + + async def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """Record configurations.""" + await self._tel_config.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + await self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + async def record_ready_time(self, ready_time): + """Record ready time.""" + await self._tel_config.record_ready_time(ready_time) + + async def add_tag(self, tag): + """Record tag string.""" + async with self._lock: + if len(self._tags) < MAX_TAGS: + self._tags.append(tag) + + async def add_config_tag(self, tag): + """Record tag string.""" + async with self._lock: + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + async def record_bur_time_out(self): + """Record block until ready timeout.""" + await self._tel_config.record_bur_time_out() + + async def record_not_ready_usage(self): + """record non-ready usage.""" + await self._tel_config.record_not_ready_usage() + + async def record_latency(self, method, latency): + """Record method latency time.""" + await self._method_latencies.add_latency(method,latency) + + async def record_exception(self, method): + """Record method exception.""" + await self._method_exceptions.add_exception(method) + + async def record_impression_stats(self, data_type, count): + """Record impressions stats.""" + await self._counters.record_impressions_value(data_type, count) + + async def record_event_stats(self, data_type, count): + """Record events stats.""" + await self._counters.record_events_value(data_type, count) + + async def record_successful_sync(self, resource, time): + """Record successful sync.""" + await self._last_synchronization.add_latency(resource, time) + + async def record_sync_error(self, resource, status): + """Record sync http error.""" + await self._http_sync_errors.add_error(resource, status) + + async def record_sync_latency(self, resource, latency): + """Record latency time.""" + await self._http_latencies.add_latency(resource, latency) + + async def record_auth_rejections(self): + """Record auth rejection.""" + await self._counters.record_auth_rejections() + + async def record_token_refreshes(self): + """Record sse token refresh.""" + await self._counters.record_token_refreshes() + + async def record_streaming_event(self, streaming_event): + """Record incoming streaming event.""" + await self._streaming_events.record_streaming_event(streaming_event) + + async def record_session_length(self, session): + """Record session length.""" + await self._counters.record_session_length(session) + + async def record_update_from_sse(self, event): + """Record update from sse.""" + await self._counters.record_update_from_sse(event) + + async def get_bur_time_outs(self): + """Get block until ready timeout.""" + return await self._tel_config.get_bur_time_outs() + + async def get_non_ready_usage(self): + """Get non-ready usage.""" + return await self._tel_config.get_non_ready_usage() + + async def get_config_stats(self): + """Get all config info.""" + return await self._tel_config.get_stats() + + async def pop_exceptions(self): + """Get and reset method exceptions.""" + return await self._method_exceptions.pop_all() + + async def pop_tags(self): + """Get and reset tags.""" + async with self._lock: + tags = self._tags + self._reset_tags() + return tags + + async def pop_config_tags(self): + """Get and reset tags.""" + async with self._lock: + tags = self._config_tags + self._reset_config_tags() + return tags + + async def pop_latencies(self): + """Get and reset eval latencies.""" + return await self._method_latencies.pop_all() + + async def get_impressions_stats(self, type): + """Get impressions stats""" + return await self._counters.get_counter_stats(type) + + async def get_events_stats(self, type): + """Get events stats""" + return await self._counters.get_counter_stats(type) + + async def get_last_synchronization(self): + """Get last sync""" + return await self._last_synchronization.get_all() + + async def pop_http_errors(self): + """Get and reset http errors.""" + return await self._http_sync_errors.pop_all() + + async def pop_http_latencies(self): + """Get and reset http latencies.""" + return await self._http_latencies.pop_all() + + async def pop_auth_rejections(self): + """Get and reset auth rejections.""" + return await self._counters.pop_auth_rejections() + + async def pop_token_refreshes(self): + """Get and reset token refreshes.""" + return await self._counters.pop_token_refreshes() + + async def pop_streaming_events(self): + return await self._streaming_events.pop_streaming_events() + + async def get_session_length(self): + """Get session length""" + return await self._counters.get_session_length() + + async def pop_update_from_sse(self, event): + """Get and reset update from sse.""" + return await self._counters.pop_update_from_sse(event) + +class LocalhostTelemetryStorage(): + """Localhost telemetry storage.""" + def do_nothing(*_, **__): + return {} + + def __getattr__(self, _): + return self.do_nothing + +class LocalhostTelemetryStorageAsync(): + """Localhost telemetry storage.""" + + async def record_ready_time(self, ready_time): + pass + + async def record_config(self, config, extra_config): + """Record configurations.""" + pass + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + pass + + async def add_tag(self, tag): + """Record tag string.""" + pass + + async def add_config_tag(self, tag): + """Record tag string.""" + pass + + async def record_bur_time_out(self): + """Record block until ready timeout.""" + pass + + async def record_not_ready_usage(self): + """record non-ready usage.""" + pass + + async def record_latency(self, method, latency): + """Record method latency time.""" + pass + + async def record_exception(self, method): + """Record method exception.""" + pass + + async def record_impression_stats(self, data_type, count): + """Record impressions stats.""" + pass + + async def record_event_stats(self, data_type, count): + """Record events stats.""" + pass + + async def record_successful_sync(self, resource, time): + """Record successful sync.""" + pass + + async def record_sync_error(self, resource, status): + """Record sync http error.""" + pass + + async def record_sync_latency(self, resource, latency): + """Record latency time.""" + pass + + async def record_auth_rejections(self): + """Record auth rejection.""" + pass + + async def record_token_refreshes(self): + """Record sse token refresh.""" + pass + + async def record_streaming_event(self, streaming_event): + """Record incoming streaming event.""" + pass + + async def record_session_length(self, session): + """Record session length.""" + pass + + async def record_update_from_sse(self, event): + """Record update from sse.""" + pass + + async def get_bur_time_outs(self): + """Get block until ready timeout.""" + pass + + async def get_non_ready_usage(self): + """Get non-ready usage.""" + pass + + async def get_config_stats(self): + """Get all config info.""" + pass + + async def pop_exceptions(self): + """Get and reset method exceptions.""" + pass + + async def pop_tags(self): + """Get and reset tags.""" + pass + + async def pop_config_tags(self): + """Get and reset tags.""" + pass + + async def pop_latencies(self): + """Get and reset eval latencies.""" + pass + + async def get_impressions_stats(self, type): + """Get impressions stats""" + pass + + async def get_events_stats(self, type): + """Get events stats""" + pass + + async def get_last_synchronization(self): + """Get last sync""" + pass + + async def pop_http_errors(self): + """Get and reset http errors.""" + pass + + async def pop_http_latencies(self): + """Get and reset http latencies.""" + pass + + async def pop_auth_rejections(self): + """Get and reset auth rejections.""" + pass + + async def pop_token_refreshes(self): + """Get and reset token refreshes.""" + pass + + async def pop_streaming_events(self): + pass + + async def get_session_length(self): + """Get session length""" + pass + + async def pop_update_from_sse(self, event): + """Get and reset update from sse.""" + pass \ No newline at end of file diff --git a/splitio/storage/pluggable.py b/splitio/storage/pluggable.py new file mode 100644 index 00000000..71e487c6 --- /dev/null +++ b/splitio/storage/pluggable.py @@ -0,0 +1,1977 @@ +"""Pluggable Storage classes.""" + +import logging +import json +import threading + +from splitio.optional.loaders import asyncio +from splitio.models import splits, segments, rule_based_segments +from splitio.models.impressions import Impression +from splitio.models.telemetry import MethodExceptions, MethodLatencies, TelemetryConfig, MAX_TAGS,\ + MethodLatenciesAsync, MethodExceptionsAsync, TelemetryConfigAsync +from splitio.storage import FlagSetsFilter, SplitStorage, SegmentStorage, ImpressionStorage, EventStorage, TelemetryStorage, RuleBasedSegmentsStorage +from splitio.util.storage_helper import get_valid_flag_sets, combine_valid_flag_sets + +_LOGGER = logging.getLogger(__name__) + +class PluggableRuleBasedSegmentsStorageBase(RuleBasedSegmentsStorage): + """Pluggable storage for rule based segments.""" + + _TILL_LENGTH = 4 + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + self._pluggable_adapter = pluggable_adapter + self._prefix = "SPLITIO.rbsegment.{segment_name}" + self._rb_segments_till_prefix = "SPLITIO.rbsegments.till" + self._rb_segment_name_length = 18 + if prefix is not None: + self._rb_segment_name_length += len(prefix) + 1 + self._prefix = prefix + "." + self._prefix + self._rb_segments_till_prefix = prefix + "." + self._rb_segments_till_prefix + + def get(self, segment_name): + """ + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str + """ + pass + + def get_change_number(self): + """ + Retrieve latest rule based segment change number. + + :rtype: int + """ + pass + + def contains(self, segment_names): + """ + Return whether the segments exists in rule based segment in cache. + + :param segment_names: segment name to validate. + :type segment_names: str + + :return: True if segment names exists. False otherwise. + :rtype: bool + """ + pass + + def get_segment_names(self): + """ + Retrieve a list of all excluded segments names. + + :return: List of segment names. + :rtype: list(str) + """ + pass + + def update(self, to_add, to_delete, new_change_number): + """ + Update rule based segment.. + + :param to_add: List of rule based segment. to add + :type to_add: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param to_delete: List of rule based segment. to delete + :type to_delete: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param new_change_number: New change number. + :type new_change_number: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def get_large_segment_names(self): + """ + Retrieve a list of all excluded large segments names. + + :return: List of segment names. + :rtype: list(str) + """ + pass + +class PluggableRuleBasedSegmentsStorage(PluggableRuleBasedSegmentsStorageBase): + """Pluggable storage for rule based segments.""" + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + PluggableRuleBasedSegmentsStorageBase.__init__(self, pluggable_adapter, prefix) + + def get(self, segment_name): + """ + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str + """ + try: + rb_segment = self._pluggable_adapter.get(self._prefix.format(segment_name=segment_name)) + if not rb_segment: + return None + + return rule_based_segments.from_raw(rb_segment) + + except Exception: + _LOGGER.error('Error getting rule based segment from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_change_number(self): + """ + Retrieve latest rule based segment change number. + + :rtype: int + """ + try: + return self._pluggable_adapter.get(self._rb_segments_till_prefix) + + except Exception: + _LOGGER.error('Error getting change number in rule based segment storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def contains(self, segment_names): + """ + Return whether the segments exists in rule based segment in cache. + + :param segment_names: segment name to validate. + :type segment_names: str + + :return: True if segment names exists. False otherwise. + :rtype: bool + """ + return set(segment_names).issubset(self.get_segment_names()) + + def get_segment_names(self): + """ + Retrieve a list of all rule based segments names. + + :return: List of segment names. + :rtype: list(str) + """ + try: + _LOGGER.error(self._rb_segment_name_length) + _LOGGER.error(self._prefix) + _LOGGER.error(self._prefix[:self._rb_segment_name_length]) + keys = [] + for key in self._pluggable_adapter.get_keys_by_prefix(self._prefix[:self._rb_segment_name_length]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key[len(self._prefix[:self._rb_segment_name_length]):]) + return keys + + except Exception: + _LOGGER.error('Error getting rule based segments names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def fetch_many(self, rb_segment_names): + """ + Retrieve rule based segments. + + :param rb_segment_names: Names of the rule based segments to fetch. + :type rb_segment_names: list(str) + + :return: A dict with rule based segment objects parsed from queue. + :rtype: dict(rb_segment_names, splitio.models.rile_based_segment.RuleBasedSegment) + """ + try: + prefix_added = [self._prefix.format(segment_name=rb_segment_name) for rb_segment_name in rb_segment_names] + return {rb_segment['name']: rule_based_segments.from_raw(rb_segment) for rb_segment in self._pluggable_adapter.get_many(prefix_added)} + + except Exception: + _LOGGER.error('Error getting rule based segments from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + +class PluggableRuleBasedSegmentsStorageAsync(PluggableRuleBasedSegmentsStorageBase): + """Pluggable storage for rule based segments.""" + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + PluggableRuleBasedSegmentsStorageBase.__init__(self, pluggable_adapter, prefix) + + async def get(self, segment_name): + """ + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str + """ + try: + rb_segment = await self._pluggable_adapter.get(self._prefix.format(segment_name=segment_name)) + if not rb_segment: + return None + + return rule_based_segments.from_raw(rb_segment) + + except Exception: + _LOGGER.error('Error getting rule based segment from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_change_number(self): + """ + Retrieve latest rule based segment change number. + + :rtype: int + """ + try: + return await self._pluggable_adapter.get(self._rb_segments_till_prefix) + + except Exception: + _LOGGER.error('Error getting change number in rule based segment storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def contains(self, segment_names): + """ + Return whether the segments exists in rule based segment in cache. + + :param segment_names: segment name to validate. + :type segment_names: str + + :return: True if segment names exists. False otherwise. + :rtype: bool + """ + return set(segment_names).issubset(await self.get_segment_names()) + + async def get_segment_names(self): + """ + Retrieve a list of all rule based segments names. + + :return: List of segment names. + :rtype: list(str) + """ + try: + keys = [] + for key in await self._pluggable_adapter.get_keys_by_prefix(self._prefix[:self._rb_segment_name_length]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key[len(self._prefix[:self._rb_segment_name_length]):]) + return keys + + except Exception: + _LOGGER.error('Error getting rule based segments names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def fetch_many(self, rb_segment_names): + """ + Retrieve rule based segments. + + :param rb_segment_names: Names of the rule based segments to fetch. + :type rb_segment_names: list(str) + + :return: A dict with rule based segment objects parsed from queue. + :rtype: dict(rb_segment_names, splitio.models.rile_based_segment.RuleBasedSegment) + """ + try: + prefix_added = [self._prefix.format(segment_name=rb_segment_name) for rb_segment_name in rb_segment_names] + return {rb_segment['name']: rule_based_segments.from_raw(rb_segment) for rb_segment in await self._pluggable_adapter.get_many(prefix_added)} + + except Exception: + _LOGGER.error('Error getting rule based segments from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + +class PluggableSplitStorageBase(SplitStorage): + """InMemory implementation of a feature flag storage.""" + + _FEATURE_FLAG_NAME_LENGTH = 19 + _TILL_LENGTH = 4 + + def __init__(self, pluggable_adapter, prefix=None, config_flag_sets=[]): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + self._pluggable_adapter = pluggable_adapter + self._prefix = "SPLITIO.split.{feature_flag_name}" + self._traffic_type_prefix = "SPLITIO.trafficType.{traffic_type_name}" + self._feature_flag_till_prefix = "SPLITIO.splits.till" + self._flag_set_prefix = 'SPLITIO.flagSet.{flag_set}' + self.flag_set_filter = FlagSetsFilter(config_flag_sets) + if prefix is not None: + self._prefix = prefix + "." + self._prefix + self._traffic_type_prefix = prefix + "." + self._traffic_type_prefix + self._feature_flag_till_prefix = prefix + "." + self._feature_flag_till_prefix + self._flag_set_prefix = prefix + "." + self._flag_set_prefix + + def get(self, feature_flag_name): + """ + Retrieve a feature flag. + + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str + + :rtype: splitio.models.splits.Split + """ + pass + + def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature flag objects parsed from queue. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) + """ + pass + + # TODO: To be added when producer mode is supported +# def put_many(self, splits, change_number): +# """ +# Store multiple splits. +# +# :param split: array of Split objects. +# :type split: splitio.models.split.Split[] +# """ +# try: +# for split in splits: +# self.put(split) +# self._pluggable_adapter.set(self._split_till_prefix, change_number) +# except Exception: +# _LOGGER.error('Error storing splits in storage') +# _LOGGER.debug('Error: ', exc_info=True) + + def update(self, to_add, to_delete, new_change_number): + """ + Update feature flag storage. + :param to_add: List of feature flags to add + :type to_add: list[splitio.models.splits.Split] + :param to_delete: List of feature flags to delete + :type to_delete: list[splitio.models.splits.Split] + :param new_change_number: New change number. + :type new_change_number: int + """ + pass +# try: +# split = self.get(feature_flag_name) +# if not split: +# _LOGGER.warning("Tried to delete nonexistant split %s. Skipping", feature_flag_name) +# return False +# self._pluggable_adapter.delete(self._prefix.format(feature_flag_name=feature_flag_name)) +# self._decrease_traffic_type_count(split.traffic_type_name) +# return True +# except Exception: +# _LOGGER.error('Error removing split from storage') +# _LOGGER.debug('Error: ', exc_info=True) +# return False + + def get_change_number(self): + """ + Retrieve latest feature flag change number. + + :rtype: int + """ + pass + + # TODO: To be added when producer mode is aupported +# def _set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. + :type new_change_number: int + """ +# pass +# try: +# self._pluggable_adapter.set(self._split_till_prefix, new_change_number) +# except Exception: +# _LOGGER.error('Error setting change number in split storage') +# _LOGGER.debug('Error: ', exc_info=True) +# return None + + def get_split_names(self): + """ + Retrieve a list of all feature flag names. + + :return: List of feature flag names. + :rtype: list(str) + """ + pass + + def traffic_type_exists(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + pass + + def kill_locally(self, feature_flag_name, default_treatment, change_number): + """ + Local kill for feature flag + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + pass + # TODO: To be added when producer mode is aupported +# try: +# split = self.get(feature_flag_name) +# if not split: +# return +# if self.get_change_number() > change_number: +# return +# split.local_kill(default_treatment, change_number) +# self._pluggable_adapter.set(self._prefix.format(feature_flag_name=feature_flag_name), split.to_json()) +# except Exception: +# _LOGGER.error('Error updating split in storage') +# _LOGGER.debug('Error: ', exc_info=True) + + # TODO: To be added when producer mode is aupported +# def _increase_traffic_type_count(self, traffic_type_name): +# """ +# Increase by one the count for a specific traffic type name. +# +# :param traffic_type_name: Traffic type to increase the count. +# :type traffic_type_name: str +# +# :return: existing count of traffic type +# :rtype: int +# """ +# try: +# return self._pluggable_adapter.increment(self._traffic_type_prefix.format(traffic_type_name=traffic_type_name), 1) +# except Exception: +# _LOGGER.error('Error updating traffic type count in split storage') +# _LOGGER.debug('Error: ', exc_info=True) +# return None + + # TODO: To be added when producer mode is aupported +# def _decrease_traffic_type_count(self, traffic_type_name): +# """ +# Decrease by one the count for a specific traffic type name. +# +# :param traffic_type_name: Traffic type to decrease the count. +# :type traffic_type_name: str +# +# :return: existing count of traffic type +# :rtype: int +# """ +# try: +# return_count = self._pluggable_adapter.decrement(self._traffic_type_prefix.format(traffic_type_name=traffic_type_name), 1) +# if return_count == 0: +# self._pluggable_adapter.delete(self._traffic_type_prefix.format(traffic_type_name=traffic_type_name)) +# except Exception: +# _LOGGER.error('Error updating traffic type count in split storage') +# _LOGGER.debug('Error: ', exc_info=True) +# return None + + def get_all_splits(self): + """ + Return all the feature flags. + + :return: List of all the feature flags. + :rtype: list + """ + pass + + def is_valid_traffic_type(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + pass + +class PluggableSplitStorage(PluggableSplitStorageBase): + """InMemory implementation of a feature flag storage.""" + + def __init__(self, pluggable_adapter, prefix=None, config_flag_sets=[]): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableSplitStorageBase.__init__(self, pluggable_adapter, prefix) + + def get(self, feature_flag_name): + """ + Retrieve a feature flag. + + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str + + :rtype: splitio.models.splits.Split + """ + try: + feature_flag = self._pluggable_adapter.get(self._prefix.format(feature_flag_name=feature_flag_name)) + if not feature_flag: + return None + + return splits.from_raw(feature_flag) + + except Exception: + _LOGGER.error('Error getting feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature flag objects parsed from queue. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) + """ + try: + prefix_added = [self._prefix.format(feature_flag_name=feature_flag_name) for feature_flag_name in feature_flag_names] + return {feature_flag['name']: splits.from_raw(feature_flag) for feature_flag in self._pluggable_adapter.get_many(prefix_added)} + + except Exception: + _LOGGER.error('Error getting feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_feature_flags_by_sets(self, flag_sets): + """ + Retrieve feature flags by flag set. + :param flag_sets: List of flag sets to fetch. + :type flag_sets: list(str) + :return: Feature flag names that are tagged with the flag set + :rtype: listt(str) + """ + try: + sets_to_fetch = get_valid_flag_sets(flag_sets, self.flag_set_filter) + if sets_to_fetch == []: + return [] + + keys = [self._flag_set_prefix.format(flag_set=flag_set) for flag_set in sets_to_fetch] + result_sets = [] + [result_sets.append(set(key)) for key in self._pluggable_adapter.get_many(keys)] + return list(combine_valid_flag_sets(result_sets)) + + except Exception: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_change_number(self): + """ + Retrieve latest feature flag change number. + + :rtype: int + """ + try: + return self._pluggable_adapter.get(self._feature_flag_till_prefix) + + except Exception: + _LOGGER.error('Error getting change number in feature flag storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_split_names(self): + """ + Retrieve a list of all feature flag names. + + :return: List of feature flag names. + :rtype: list(str) + """ + try: + keys = [] + for key in self._pluggable_adapter.get_keys_by_prefix(self._prefix[:-self._FEATURE_FLAG_NAME_LENGTH]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key[len(self._prefix[:-self._FEATURE_FLAG_NAME_LENGTH]):]) + return keys + + except Exception: + _LOGGER.error('Error getting feature flag names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def traffic_type_exists(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + return self._pluggable_adapter.get(self._traffic_type_prefix.format(traffic_type_name=traffic_type_name)) != None + + except Exception: + _LOGGER.error('Error getting feature flag info from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_all_splits(self): + """ + Return all the feature flags. + + :return: List of all the feature flags. + :rtype: list + """ + try: + keys = [] + for key in self._pluggable_adapter.get_keys_by_prefix(self._prefix[:-self._FEATURE_FLAG_NAME_LENGTH]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key) + return [splits.from_raw(feature_flag) for feature_flag in self._pluggable_adapter.get_many(keys)] + + except Exception: + _LOGGER.error('Error fetching feature flags from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def is_valid_traffic_type(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + return self.traffic_type_exists(traffic_type_name) + + except Exception: + _LOGGER.error('Error getting traffic type info from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + +class PluggableSplitStorageAsync(PluggableSplitStorageBase): + """InMemory async implementation of a feature flag storage.""" + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableSplitStorageBase.__init__(self, pluggable_adapter, prefix) + + async def get(self, feature_flag_name): + """ + Retrieve a feature flag. + + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str + + :rtype: splitio.models.splits.Split + """ + try: + feature_flag = await self._pluggable_adapter.get(self._prefix.format(feature_flag_name=feature_flag_name)) + if not feature_flag: + return None + + return splits.from_raw(feature_flag) + + except Exception: + _LOGGER.error('Error getting feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature_flag objects parsed from queue. + :rtype: dict(split_feature_flag, splitio.models.splits.Split) + """ + try: + prefix_added = [self._prefix.format(feature_flag_name=feature_flag_name) for feature_flag_name in feature_flag_names] + return {feature_flag['name']: splits.from_raw(feature_flag) for feature_flag in await self._pluggable_adapter.get_many(prefix_added)} + + except Exception: + _LOGGER.error('Error getting feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_feature_flags_by_sets(self, flag_sets): + """ + Retrieve feature flags by flag set. + :param flag_sets: List of flag sets to fetch. + :type flag_sets: list(str) + :return: Feature flag names that are tagged with the flag set + :rtype: listt(str) + """ + try: + sets_to_fetch = get_valid_flag_sets(flag_sets, self.flag_set_filter) + if sets_to_fetch == []: + return [] + + keys = [self._flag_set_prefix.format(flag_set=flag_set) for flag_set in sets_to_fetch] + result_sets = [] + [result_sets.append(set(key)) for key in await self._pluggable_adapter.get_many(keys)] + return list(combine_valid_flag_sets(result_sets)) + + except Exception: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_change_number(self): + """ + Retrieve latest feature flag change number. + + :rtype: int + """ + try: + return await self._pluggable_adapter.get(self._feature_flag_till_prefix) + + except Exception: + _LOGGER.error('Error getting change number in feature flag storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_split_names(self): + """ + Retrieve a list of all feature flag names. + + :return: List of feature flag names. + :rtype: list(str) + """ + try: + keys = [] + for key in await self._pluggable_adapter.get_keys_by_prefix(self._prefix[:-self._FEATURE_FLAG_NAME_LENGTH]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key[len(self._prefix[:-self._FEATURE_FLAG_NAME_LENGTH]):]) + return keys + + except Exception: + _LOGGER.error('Error getting feature flag names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def traffic_type_exists(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + return await self._pluggable_adapter.get(self._traffic_type_prefix.format(traffic_type_name=traffic_type_name)) != None + + except Exception: + _LOGGER.error('Error getting traffic type info from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_all_splits(self): + """ + Return all the feature flags. + + :return: List of all the feature flags. + :rtype: list + """ + try: + keys = [] + for key in await self._pluggable_adapter.get_keys_by_prefix(self._prefix[:-self._FEATURE_FLAG_NAME_LENGTH]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key) + return [splits.from_raw(feature_flag) for feature_flag in await self._pluggable_adapter.get_many(keys)] + + except Exception: + _LOGGER.error('Error fetching feature flags from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def is_valid_traffic_type(self, traffic_type_name): + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + return await self.traffic_type_exists(traffic_type_name) + + except Exception: + _LOGGER.error('Error getting feature flag info from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + +class PluggableSegmentStorageBase(SegmentStorage): + """Pluggable async implementation of segment storage.""" + _SEGMENT_NAME_LENGTH = 14 + _TILL_LENGTH = 4 + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + self._pluggable_adapter = pluggable_adapter + self._prefix = "SPLITIO.segment.{segment_name}" + self._segment_till_prefix = "SPLITIO.segment.{segment_name}.till" + if prefix is not None: + self._prefix = prefix + "." + self._prefix + self._segment_till_prefix = prefix + "." + self._segment_till_prefix + + def update(self, segment_name, to_add, to_remove, change_number=None): + """ + Update a segment. Create it if it doesn't exist. + + :param segment_name: Name of the segment to update. + :type segment_name: str + :param to_add: Set of members to add to the segment. + :type to_add: set + :param to_remove: List of members to remove from the segment. + :type to_remove: Set + """ + pass + # TODO: To be added when producer mode is aupported +# try: +# if to_add is not None: +# self._pluggable_adapter.add_items(self._prefix.format(segment_name=segment_name), to_add) +# if to_remove is not None: +# self._pluggable_adapter.remove_items(self._prefix.format(segment_name=segment_name), to_remove) +# if change_number is not None: +# self._pluggable_adapter.set(self._segment_till_prefix.format(segment_name=segment_name), change_number) +# except Exception: +# _LOGGER.error('Error updating segment storage') +# _LOGGER.debug('Error: ', exc_info=True) + + def set_change_number(self, segment_name, change_number): + """ + Store a segment change number. + + :param segment_name: segment name + :type segment_name: str + :param change_number: change number + :type segment_name: int + """ + pass + # TODO: To be added when producer mode is aupported +# try: +# self._pluggable_adapter.set(self._segment_till_prefix.format(segment_name=segment_name), change_number) +# except Exception: +# _LOGGER.error('Error updating segment change number') +# _LOGGER.debug('Error: ', exc_info=True) + + def get_change_number(self, segment_name): + """ + Get a segment change number. + + :param segment_name: segment name + :type segment_name: str + + :return: change number + :rtype: int + """ + pass + + def get_segment_names(self): + """ + Get list of segment names. + + :return: list of segment names + :rtype: str[] + """ + pass + + # TODO: To be added in the future because this data is not being sent by telemetry in consumer/synchronizer mode +# def get_keys(self, segment_name): +# """ +# Get keys of a segment. +# +# :param segment_name: segment name +# :type segment_name: str +# +# :return: list of segment keys +# :rtype: str[] +# """ +# try: +# return list(self._pluggable_adapter.get(self._prefix.format(segment_name=segment_name))) +# except Exception: +# _LOGGER.error('Error getting segments keys') +# _LOGGER.debug('Error: ', exc_info=True) +# return None + + def segment_contains(self, segment_name, key): + """ + Check if segment contains a key + + :param segment_name: segment name + :type segment_name: str + :param key: key + :type key: str + + :return: True if found, otherwise False + :rtype: bool + """ + pass + + def get_segment_keys_count(self): + """ + Get count of all keys in segments. + + :return: keys count + :rtype: int + """ + pass + # TODO: To be added when producer mode is aupported +# try: +# return sum([self._pluggable_adapter.get_items_count(key) for key in self._pluggable_adapter.get_keys_by_prefix(self._prefix)]) +# except Exception: +# _LOGGER.error('Error getting segment keys') +# _LOGGER.debug('Error: ', exc_info=True) +# return None + + def get(self, segment_name): + """ + Get a segment + + :param segment_name: segment name + :type segment_name: str + + :return: segment object + :rtype: splitio.models.segments.Segment + """ + pass + + def put(self, segment): + """ + Store a segment. + + :param segment: Segment to store. + :type segment: splitio.models.segment.Segment + """ + pass + # TODO: To be added when producer mode is aupported +# try: +# self._pluggable_adapter.add_items(self._prefix.format(segment_name=segment.name), list(segment.keys)) +# if segment.change_number is not None: +# self._pluggable_adapter.set(self._segment_till_prefix.format(segment_name=segment.name), segment.change_number) +# except Exception: +# _LOGGER.error('Error updating segment storage') +# _LOGGER.debug('Error: ', exc_info=True) + + +class PluggableSegmentStorage(PluggableSegmentStorageBase): + """Pluggable implementation of segment storage.""" + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableSegmentStorageBase.__init__(self, pluggable_adapter, prefix) + + def get_change_number(self, segment_name): + """ + Get a segment change number. + + :param segment_name: segment name + :type segment_name: str + + :return: change number + :rtype: int + """ + try: + return self._pluggable_adapter.get(self._segment_till_prefix.format(segment_name=segment_name)) + + except Exception: + _LOGGER.error('Error fetching segment change number') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_segment_names(self): + """ + Get list of segment names. + + :return: list of segment names + :rtype: str[] + """ + try: + keys = [] + for key in self._pluggable_adapter.get_keys_by_prefix(self._prefix[:-self._SEGMENT_NAME_LENGTH]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key[len(self._prefix[:-self._SEGMENT_NAME_LENGTH]):]) + return keys + + except Exception: + _LOGGER.error('Error getting segments') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def segment_contains(self, segment_name, key): + """ + Check if segment contains a key + + :param segment_name: segment name + :type segment_name: str + :param key: key + :type key: str + + :return: True if found, otherwise False + :rtype: bool + """ + try: + return self._pluggable_adapter.item_contains(self._prefix.format(segment_name=segment_name), key) + + except Exception: + _LOGGER.error('Error checking segment key') + _LOGGER.debug('Error: ', exc_info=True) + return False + + def get(self, segment_name): + """ + Get a segment + + :param segment_name: segment name + :type segment_name: str + + :return: segment object + :rtype: splitio.models.segments.Segment + """ + try: + return segments.from_raw({'name': segment_name, 'added': self._pluggable_adapter.get_items(self._prefix.format(segment_name=segment_name)), 'removed': [], 'till': self._pluggable_adapter.get(self._segment_till_prefix.format(segment_name=segment_name))}) + + except Exception: + _LOGGER.error('Error getting segment') + _LOGGER.debug('Error: ', exc_info=True) + return None + +class PluggableSegmentStorageAsync(PluggableSegmentStorageBase): + """Pluggable async implementation of segment storage.""" + + def __init__(self, pluggable_adapter, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableSegmentStorageBase.__init__(self, pluggable_adapter, prefix) + + async def get_change_number(self, segment_name): + """ + Get a segment change number. + + :param segment_name: segment name + :type segment_name: str + + :return: change number + :rtype: int + """ + try: + return await self._pluggable_adapter.get(self._segment_till_prefix.format(segment_name=segment_name)) + + except Exception: + _LOGGER.error('Error fetching segment change number') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_segment_names(self): + """ + Get list of segment names. + + :return: list of segment names + :rtype: str[] + """ + try: + keys = [] + for key in await self._pluggable_adapter.get_keys_by_prefix(self._prefix[:-self._SEGMENT_NAME_LENGTH]): + if key[-self._TILL_LENGTH:] != 'till': + keys.append(key[len(self._prefix[:-self._SEGMENT_NAME_LENGTH]):]) + return keys + + except Exception: + _LOGGER.error('Error getting segments') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def segment_contains(self, segment_name, key): + """ + Check if segment contains a key + + :param segment_name: segment name + :type segment_name: str + :param key: key + :type key: str + + :return: True if found, otherwise False + :rtype: bool + """ + try: + return await self._pluggable_adapter.item_contains(self._prefix.format(segment_name=segment_name), key) + + except Exception: + _LOGGER.error('Error checking segment key') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get(self, segment_name): + """ + Get a segment + + :param segment_name: segment name + :type segment_name: str + + :return: segment object + :rtype: splitio.models.segments.Segment + """ + try: + return segments.from_raw({'name': segment_name, 'added': await self._pluggable_adapter.get_items(self._prefix.format(segment_name=segment_name)), 'removed': [], 'till': await self._pluggable_adapter.get(self._segment_till_prefix.format(segment_name=segment_name))}) + + except Exception: + _LOGGER.error('Error getting segment') + _LOGGER.debug('Error: ', exc_info=True) + return None + +class PluggableImpressionsStorageBase(ImpressionStorage): + """Pluggable Impressions storage class.""" + + IMPRESSIONS_KEY_DEFAULT_TTL = 3600 + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + self._pluggable_adapter = pluggable_adapter + self._sdk_metadata = { + 's': sdk_metadata.sdk_version, + 'n': sdk_metadata.instance_name, + 'i': sdk_metadata.instance_ip, + } + self._impressions_queue_key = 'SPLITIO.impressions' + if prefix is not None: + self._impressions_queue_key = prefix + "." + self._impressions_queue_key + + def _wrap_impressions(self, impressions): + """ + Wrap impressions to be stored in storage + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Processed impressions. + :rtype: list[splitio.models.impressions.Impression] + """ + bulk_impressions = [] + for impression in impressions: + if isinstance(impression, Impression): + to_store = { + 'm': self._sdk_metadata, + 'i': { + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + 'properties': impression.properties + } + } + bulk_impressions.append(json.dumps(to_store)) + return bulk_impressions + + def put(self, impressions): + """ + Add an impression to the pluggable storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + pass + + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + def pop_many(self, count): + """ + Pop the oldest N events from storage. + + :param count: Number of events to pop. + :type count: int + """ + raise NotImplementedError('Only consumer mode is supported.') + + def clear(self): + """ + Clear data. + """ + raise NotImplementedError('Only consumer mode is supported.') + + +class PluggableImpressionsStorage(PluggableImpressionsStorageBase): + """Pluggable Impressions storage class.""" + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableImpressionsStorageBase.__init__(self, pluggable_adapter, sdk_metadata, prefix) + + def put(self, impressions): + """ + Add an impression to the pluggable storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + bulk_impressions = self._wrap_impressions(impressions) + try: + total_keys = self._pluggable_adapter.push_items(self._impressions_queue_key, *bulk_impressions) + self.expire_key(total_keys, len(bulk_impressions)) + return True + + except Exception: + _LOGGER.error('Something went wrong when trying to add impression to storage') + _LOGGER.error('Error: ', exc_info=True) + return False + + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._pluggable_adapter.expire(self._impressions_queue_key, self.IMPRESSIONS_KEY_DEFAULT_TTL) + + +class PluggableImpressionsStorageAsync(PluggableImpressionsStorageBase): + """Pluggable Impressions storage class.""" + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableImpressionsStorageBase.__init__(self, pluggable_adapter, sdk_metadata, prefix) + + async def put(self, impressions): + """ + Add an impression to the pluggable storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + bulk_impressions = self._wrap_impressions(impressions) + try: + total_keys = await self._pluggable_adapter.push_items(self._impressions_queue_key, *bulk_impressions) + await self.expire_key(total_keys, len(bulk_impressions)) + return True + + except Exception: + _LOGGER.error('Something went wrong when trying to add impression to storage') + _LOGGER.error('Error: ', exc_info=True) + return False + + async def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._pluggable_adapter.expire(self._impressions_queue_key, self.IMPRESSIONS_KEY_DEFAULT_TTL) + + +class PluggableEventsStorageBase(EventStorage): + """Pluggable Event storage class.""" + + _EVENTS_KEY_DEFAULT_TTL = 3600 + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + self._pluggable_adapter = pluggable_adapter + self._sdk_metadata = { + 's': sdk_metadata.sdk_version, + 'n': sdk_metadata.instance_name, + 'i': sdk_metadata.instance_ip, + } + self._events_queue_key = 'SPLITIO.events' + if prefix is not None: + self._events_queue_key = prefix + "." + self._events_queue_key + + def _wrap_events(self, events): + return [ + json.dumps({ + 'e': { + 'key': e.event.key, + 'trafficTypeName': e.event.traffic_type_name, + 'eventTypeId': e.event.event_type_id, + 'value': e.event.value, + 'timestamp': e.event.timestamp, + 'properties': e.event.properties, + }, + 'm': self._sdk_metadata + }) + for e in events + ] + + def put(self, events): + """ + Add an event to the redis storage. + + :param event: Event to add to the queue. + :type event: splitio.models.events.Event + + :return: Whether the event has been added or not. + :rtype: bool + """ + pass + + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + def pop_many(self, count): + """ + Pop the oldest N events from storage. + + :param count: Number of events to pop. + :type count: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def clear(self): + """ + Clear data. + """ + raise NotImplementedError('Not supported for redis.') + +class PluggableEventsStorage(PluggableEventsStorageBase): + """Pluggable Event storage class.""" + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableEventsStorageBase.__init__(self, pluggable_adapter, sdk_metadata, prefix) + + def put(self, events): + """ + Add an event to the redis storage. + + :param event: Event to add to the queue. + :type event: splitio.models.events.Event + + :return: Whether the event has been added or not. + :rtype: bool + """ + to_store = self._wrap_events(events) + try: + total_keys = self._pluggable_adapter.push_items(self._events_queue_key, *to_store) + self.expire_key(total_keys, len(to_store)) + return True + + except Exception: + _LOGGER.error('Something went wrong when trying to add event to redis') + _LOGGER.debug('Error: ', exc_info=True) + return False + + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._pluggable_adapter.expire(self._events_queue_key, self._EVENTS_KEY_DEFAULT_TTL) + + +class PluggableEventsStorageAsync(PluggableEventsStorageBase): + """Pluggable Event storage class.""" + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + PluggableEventsStorageBase.__init__(self, pluggable_adapter, sdk_metadata, prefix) + + async def put(self, events): + """ + Add an event to the redis storage. + + :param event: Event to add to the queue. + :type event: splitio.models.events.Event + + :return: Whether the event has been added or not. + :rtype: bool + """ + to_store = self._wrap_events(events) + try: + total_keys = await self._pluggable_adapter.push_items(self._events_queue_key, *to_store) + await self.expire_key(total_keys, len(to_store)) + return True + + except Exception: + _LOGGER.error('Something went wrong when trying to add event to redis') + _LOGGER.debug('Error: ', exc_info=True) + return False + + async def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._pluggable_adapter.expire(self._events_queue_key, self._EVENTS_KEY_DEFAULT_TTL) + + +class PluggableTelemetryStorageBase(TelemetryStorage): + """Pluggable telemetry storage class.""" + + _TELEMETRY_KEY_DEFAULT_TTL = 3600 + + def _reset_config_tags(self): + """Reset config tags.""" + pass + + def add_config_tag(self, tag): + """ + Record tag string. + + :param tag: tag to be added + :type tag: str + """ + pass + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + initilize telemetry objects + + :param config: factory configuration parameters + :type config: Dict + :param extra_config: any extra configs + :type extra_config: Dict + """ + pass + + def pop_config_tags(self): + """Get and reset configs.""" + pass + + def push_config_stats(self): + """push config stats to storage.""" + pass + + def _format_config_stats(self): + """format only selected config stats to json""" + pass + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """ + Record active and redundant factories. + + :param active_factory_count: active factory count + :type active_factory_count: int + :param redundant_factory_count: redundant factory count + :type redundant_factory_count: int + """ + pass + + def record_latency(self, method, bucket): + """ + record latency data + + :param method: method name + :type method: string + :param latency: latency + :type latency: int64 + """ + pass + + def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + pass + + def record_not_ready_usage(self): + """Not implemented""" + pass + + def record_bur_time_out(self): + """Not implemented""" + pass + + def record_impression_stats(self, data_type, count): + """Not implemented""" + pass + + def expire_latency_keys(self, total_keys, inserted): + """ + Set expire ttl for a latency key in storage + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire ttl for a key in storage if total keys equal inserted + + :param queue_keys: key to be set + :type queue_keys: str + :param ey_default_ttl: ttl value + :type ey_default_ttl: int + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + +class PluggableTelemetryStorage(PluggableTelemetryStorageBase): + """Pluggable telemetry storage class.""" + + def __init__(self, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + self._pluggable_adapter = pluggable_adapter + self._sdk_metadata = sdk_metadata.sdk_version + '/' + sdk_metadata.instance_name + '/' + sdk_metadata.instance_ip + self._telemetry_config_key = 'SPLITIO.telemetry.init' + self._telemetry_latencies_key = 'SPLITIO.telemetry.latencies' + self._telemetry_exceptions_key = 'SPLITIO.telemetry.exceptions' + if prefix is not None: + self._telemetry_config_key = prefix + "." + self._telemetry_config_key + self._telemetry_latencies_key = prefix + "." + self._telemetry_latencies_key + self._telemetry_exceptions_key = prefix + "." + self._telemetry_exceptions_key + + self._lock = threading.RLock() + self._reset_config_tags() + self._method_latencies = MethodLatencies() + self._method_exceptions = MethodExceptions() + self._tel_config = TelemetryConfig() + + def _reset_config_tags(self): + """Reset config tags.""" + with self._lock: + self._config_tags = [] + + def add_config_tag(self, tag): + """ + Record tag string. + + :param tag: tag to be added + :type tag: str + """ + with self._lock: + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + initilize telemetry objects + + :param config: factory configuration parameters + :type config: Dict + :param extra_config: any extra configs + :type extra_config: Dict + """ + self._tel_config.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + + def pop_config_tags(self): + """Get and reset configs.""" + with self._lock: + tags = self._config_tags + self._reset_config_tags() + return tags + + def push_config_stats(self): + """push config stats to storage.""" + self._pluggable_adapter.set(self._telemetry_config_key + "::" + self._sdk_metadata, str(self._format_config_stats())) + + def _format_config_stats(self): + """format only selected config stats to json""" + config_stats = self._tel_config.get_stats() + return json.dumps({ + 'aF': config_stats['aF'], + 'rF': config_stats['rF'], + 'sT': config_stats['sT'], + 'oM': config_stats['oM'], + 't': self.pop_config_tags() + }) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """ + Record active and redundant factories. + + :param active_factory_count: active factory count + :type active_factory_count: int + :param redundant_factory_count: redundant factory count + :type redundant_factory_count: int + """ + self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def record_latency(self, method, bucket): + """ + record latency data + + :param method: method name + :type method: string + :param latency: latency + :type latency: int64 + """ + latency_key = self._telemetry_latencies_key + '::' + self._sdk_metadata + '/' + method.value + '/' + str(bucket) + result = self._pluggable_adapter.increment(latency_key, 1) + self.expire_keys(latency_key, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result) + + def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + except_key = self._telemetry_exceptions_key + "::" + self._sdk_metadata + '/' + method.value + result = self._pluggable_adapter.increment(except_key, 1) + self.expire_keys(except_key, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result) + + def expire_latency_keys(self, total_keys, inserted): + """ + Set expire ttl for a latency key in storage + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + self.expire_keys(self._telemetry_latencies_key, self._TELEMETRY_KEY_DEFAULT_TTL, total_keys, inserted) + + def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire ttl for a key in storage if total keys equal inserted + + :param queue_keys: key to be set + :type queue_keys: str + :param ey_default_ttl: ttl value + :type ey_default_ttl: int + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._pluggable_adapter.expire(queue_key, key_default_ttl) + + def record_bur_time_out(self): + """record BUR timeouts""" + pass + + def record_ready_time(self, ready_time): + """Record ready time.""" + pass + + +class PluggableTelemetryStorageAsync(PluggableTelemetryStorageBase): + """Pluggable telemetry storage class.""" + + @classmethod + async def create(cls, pluggable_adapter, sdk_metadata, prefix=None): + """ + Class constructor. + + :param pluggable_adapter: Storage client or compliant interface. + :type pluggable_adapter: TBD + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + :param prefix: optional, prefix to storage keys + :type prefix: str + """ + self = cls() + self._pluggable_adapter = pluggable_adapter + self._sdk_metadata = sdk_metadata.sdk_version + '/' + sdk_metadata.instance_name + '/' + sdk_metadata.instance_ip + self._telemetry_config_key = 'SPLITIO.telemetry.init' + self._telemetry_latencies_key = 'SPLITIO.telemetry.latencies' + self._telemetry_exceptions_key = 'SPLITIO.telemetry.exceptions' + if prefix is not None: + self._telemetry_config_key = prefix + "." + self._telemetry_config_key + self._telemetry_latencies_key = prefix + "." + self._telemetry_latencies_key + self._telemetry_exceptions_key = prefix + "." + self._telemetry_exceptions_key + + self._lock = asyncio.Lock() + await self._reset_config_tags() + self._method_latencies = await MethodLatenciesAsync.create() + self._method_exceptions = await MethodExceptionsAsync.create() + self._tel_config = await TelemetryConfigAsync.create() + return self + + async def _reset_config_tags(self): + """Reset config tags.""" + async with self._lock: + self._config_tags = [] + + async def add_config_tag(self, tag): + """ + Record tag string. + + :param tag: tag to be added + :type tag: str + """ + async with self._lock: + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + async def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + initilize telemetry objects + + :param config: factory configuration parameters + :type config: Dict + :param extra_config: any extra configs + :type extra_config: Dict + """ + await self._tel_config.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + + async def pop_config_tags(self): + """Get and reset configs.""" + tags = self._config_tags + await self._reset_config_tags() + return tags + + async def push_config_stats(self): + """push config stats to storage.""" + await self._pluggable_adapter.set(self._telemetry_config_key + "::" + self._sdk_metadata, str(await self._format_config_stats())) + + async def _format_config_stats(self): + """format only selected config stats to json""" + config_stats = await self._tel_config.get_stats() + return json.dumps({ + 'aF': config_stats['aF'], + 'rF': config_stats['rF'], + 'sT': config_stats['sT'], + 'oM': config_stats['oM'], + 't': await self.pop_config_tags() + }) + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """ + Record active and redundant factories. + + :param active_factory_count: active factory count + :type active_factory_count: int + :param redundant_factory_count: redundant factory count + :type redundant_factory_count: int + """ + await self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + async def record_latency(self, method, bucket): + """ + record latency data + + :param method: method name + :type method: string + :param latency: latency + :type latency: int64 + """ + latency_key = self._telemetry_latencies_key + '::' + self._sdk_metadata + '/' + method.value + '/' + str(bucket) + result = await self._pluggable_adapter.increment(latency_key, 1) + await self.expire_keys(latency_key, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result) + + async def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + except_key = self._telemetry_exceptions_key + "::" + self._sdk_metadata + '/' + method.value + result = await self._pluggable_adapter.increment(except_key, 1) + await self.expire_keys(except_key, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result) + + async def expire_latency_keys(self, total_keys, inserted): + """ + Set expire ttl for a latency key in storage + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + await self.expire_keys(self._telemetry_latencies_key, self._TELEMETRY_KEY_DEFAULT_TTL, total_keys, inserted) + + async def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire ttl for a key in storage if total keys equal inserted + + :param queue_keys: key to be set + :type queue_keys: str + :param ey_default_ttl: ttl value + :type ey_default_ttl: int + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._pluggable_adapter.expire(queue_key, key_default_ttl) + + async def record_bur_time_out(self): + """record BUR timeouts""" + pass + + async def record_ready_time(self, ready_time): + """Record ready time.""" + pass + + async def record_not_ready_usage(self): + """Not implemented""" + pass + + async def record_impression_stats(self, data_type, count): + """Not implemented""" + pass diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py index cdc79b29..b8fe27ad 100644 --- a/splitio/storage/redis.py +++ b/splitio/storage/redis.py @@ -1,26 +1,28 @@ """Redis storage module.""" import json import logging +import threading from splitio.models.impressions import Impression -from splitio.models import splits, segments +from splitio.models import splits, segments, rule_based_segments +from splitio.models.telemetry import TelemetryConfig, TelemetryConfigAsync from splitio.storage import SplitStorage, SegmentStorage, ImpressionStorage, EventStorage, \ - ImpressionPipelinedStorage + ImpressionPipelinedStorage, TelemetryStorage, FlagSetsFilter, RuleBasedSegmentsStorage from splitio.storage.adapters.redis import RedisAdapterException from splitio.storage.adapters.cache_trait import decorate as add_cache, DEFAULT_MAX_AGE - +from splitio.storage.adapters.cache_trait import LocalMemoryCache, LocalMemoryCacheAsync +from splitio.util.storage_helper import get_valid_flag_sets, combine_valid_flag_sets _LOGGER = logging.getLogger(__name__) - - -class RedisSplitStorage(SplitStorage): - """Redis-based storage for splits.""" - - _SPLIT_KEY = 'SPLITIO.split.{split_name}' - _SPLIT_TILL_KEY = 'SPLITIO.splits.till' - _TRAFFIC_TYPE_KEY = 'SPLITIO.trafficType.{traffic_type_name}' - - def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE): +MAX_TAGS = 10 + +class RedisRuleBasedSegmentsStorage(RuleBasedSegmentsStorage): + """Redis-based storage for rule based segments.""" + + _RB_SEGMENT_KEY = 'SPLITIO.rbsegment.{segment_name}' + _RB_SEGMENT_TILL_KEY = 'SPLITIO.rbsegments.till' + + def __init__(self, redis_client): """ Class constructor. @@ -28,351 +30,1064 @@ def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE): :type redis_client: splitio.storage.adapters.redis.RedisAdapter """ self._redis = redis_client - if enable_caching: - self.get = add_cache(lambda *p, **_: p[0], max_age)(self.get) - self.is_valid_traffic_type = add_cache(lambda *p, **_: p[0], max_age)(self.is_valid_traffic_type) # pylint: disable=line-too-long - self.fetch_many = add_cache(lambda *p, **_: frozenset(p[0]), max_age)(self.fetch_many) + self._pipe = self._redis.pipeline - def _get_key(self, split_name): + def _get_key(self, segment_name): """ - Use the provided split_name to build the appropriate redis key. + Use the provided feature_flag_name to build the appropriate redis key. - :param split_name: Name of the split to interact with in redis. - :type split_name: str + :param feature_flag_name: Name of the feature flag to interact with in redis. + :type feature_flag_name: str :return: Redis key. :rtype: str. """ - return self._SPLIT_KEY.format(split_name=split_name) + return self._RB_SEGMENT_KEY.format(segment_name=segment_name) + + def get(self, segment_name): + """ + Retrieve a rule based segment. - def _get_traffic_type_key(self, traffic_type_name): + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :rtype: str """ - Use the provided split_name to build the appropriate redis key. + try: + raw = self._redis.get(self._get_key(segment_name)) + _LOGGER.debug("Fetchting rule based segment [%s] from redis" % segment_name) + _LOGGER.debug(raw) + return rule_based_segments.from_raw(json.loads(raw)) if raw is not None else None - :param split_name: Name of the split to interact with in redis. - :type split_name: str + except RedisAdapterException: + _LOGGER.error('Error fetching rule based segment from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None - :return: Redis key. - :rtype: str. + def update(self, to_add, to_delete, new_change_number): """ - return self._TRAFFIC_TYPE_KEY.format(traffic_type_name=traffic_type_name) + Update rule based segment.. - def get(self, split_name): # pylint: disable=method-hidden + :param to_add: List of rule based segment. to add + :type to_add: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param to_delete: List of rule based segment. to delete + :type to_delete: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param new_change_number: New change number. + :type new_change_number: int """ - Retrieve a split. + raise NotImplementedError('Only redis-consumer mode is supported.') - :param split_name: Name of the feature to fetch. - :type split_name: str + def get_change_number(self): + """ + Retrieve latest rule based segment change number. - :return: A split object parsed from redis if the key exists. None otherwise - :rtype: splitio.models.splits.Split + :rtype: int """ try: - raw = self._redis.get(self._get_key(split_name)) - return splits.from_raw(json.loads(raw)) if raw is not None else None + stored_value = self._redis.get(self._RB_SEGMENT_TILL_KEY) + _LOGGER.debug("Fetching rule based segment Change Number from redis: %s" % stored_value) + return json.loads(stored_value) if stored_value is not None else None + except RedisAdapterException: - _LOGGER.error('Error fetching split from storage') + _LOGGER.error('Error fetching rule based segment change number from storage') _LOGGER.debug('Error: ', exc_info=True) return None + + def contains(self, segment_names): + """ + Return whether the segments exists in rule based segment in cache. + + :param segment_names: segment name to validate. + :type segment_names: str + + :return: True if segment names exists. False otherwise. + :rtype: bool + """ + return set(segment_names).issubset(self.get_segment_names()) + + def get_segment_names(self): + """ + Retrieve a list of all rule based segments names. + + :return: List of segment names. + :rtype: list(str) + """ + try: + keys = self._redis.keys(self._get_key('*')) + _LOGGER.debug("Fetchting rule based segments names from redis: %s" % keys) + return [key.replace(self._get_key(''), '') for key in keys] + + except RedisAdapterException: + _LOGGER.error('Error fetching rule based segments names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return [] + + def get_large_segment_names(self): + """ + Retrieve a list of all excluded large segments names. - def fetch_many(self, split_names): + :return: List of segment names. + :rtype: list(str) + """ + pass + + def fetch_many(self, segment_names): """ - Retrieve splits. + Retrieve rule based segment. - :param split_names: Names of the features to fetch. - :type split_name: list(str) + :param segment_names: Names of the rule based segments to fetch. + :type segment_names: list(str) - :return: A dict with split objects parsed from redis. - :rtype: dict(split_name, splitio.models.splits.Split) + :return: A dict with rule based segment objects parsed from redis. + :rtype: dict(segment_name, splitio.models.rule_based_segment.RuleBasedSegment) """ to_return = dict() + if len(segment_names) == 0: + return to_return + try: - keys = [self._get_key(split_name) for split_name in split_names] - raw_splits = self._redis.mget(keys) - for i in range(len(split_names)): - split = None + keys = [self._get_key(segment_name) for segment_name in segment_names] + raw_rbs_segments = self._redis.mget(keys) + _LOGGER.debug("Fetchting rule based segment [%s] from redis" % segment_names) + _LOGGER.debug(raw_rbs_segments) + for i in range(len(raw_rbs_segments)): + rbs_segment = None try: - split = splits.from_raw(json.loads(raw_splits[i])) + rbs_segment = rule_based_segments.from_raw(json.loads(raw_rbs_segments[i])) except (ValueError, TypeError): - _LOGGER.error('Could not parse split.') - _LOGGER.debug("Raw split that failed parsing attempt: %s", raw_splits[i]) - to_return[split_names[i]] = split + _LOGGER.error('Could not parse rule based segment.') + _LOGGER.debug("Raw rule based segment that failed parsing attempt: %s", raw_rbs_segments[i]) + to_return[segment_names[i]] = rbs_segment except RedisAdapterException: - _LOGGER.error('Error fetching splits from storage') + _LOGGER.error('Error fetching rule based segments from storage') _LOGGER.debug('Error: ', exc_info=True) return to_return - def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hidden +class RedisRuleBasedSegmentsStorageAsync(RuleBasedSegmentsStorage): + """Redis-based storage for rule based segments.""" + + _RB_SEGMENT_KEY = 'SPLITIO.rbsegment.{segment_name}' + _RB_SEGMENT_TILL_KEY = 'SPLITIO.rbsegments.till' + + def __init__(self, redis_client): """ - Return whether the traffic type exists in at least one split in cache. - - :param traffic_type_name: Traffic type to validate. - :type traffic_type_name: str + Class constructor. - :return: True if the traffic type is valid. False otherwise. - :rtype: bool + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter """ - try: - raw = self._redis.get(self._get_traffic_type_key(traffic_type_name)) - count = json.loads(raw) if raw else 0 - return count > 0 - except RedisAdapterException: - _LOGGER.error('Error fetching split from storage') - _LOGGER.debug('Error: ', exc_info=True) - return False + self._redis = redis_client + self._pipe = self._redis.pipeline - def put(self, split): + def _get_key(self, segment_name): """ - Store a split. + Use the provided feature_flag_name to build the appropriate redis key. - :param split: Split object to store - :type split_name: splitio.models.splits.Split + :param feature_flag_name: Name of the feature flag to interact with in redis. + :type feature_flag_name: str + + :return: Redis key. + :rtype: str. """ - raise NotImplementedError('Only redis-consumer mode is supported.') + return self._RB_SEGMENT_KEY.format(segment_name=segment_name) + + async def get(self, segment_name): + """ + Retrieve a rule based segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str - def remove(self, split_name): + :rtype: str """ - Remove a split from storage. + try: + raw = await self._redis.get(self._get_key(segment_name)) + _LOGGER.debug("Fetchting rule based segment [%s] from redis" % segment_name) + _LOGGER.debug(raw) + return rule_based_segments.from_raw(json.loads(raw)) if raw is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching rule based segment from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None - :param split_name: Name of the feature to remove. - :type split_name: str + async def update(self, to_add, to_delete, new_change_number): + """ + Update rule based segment.. - :return: True if the split was found and removed. False otherwise. - :rtype: bool + :param to_add: List of rule based segment. to add + :type to_add: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param to_delete: List of rule based segment. to delete + :type to_delete: list[splitio.models.rule_based_segments.RuleBasedSegment] + :param new_change_number: New change number. + :type new_change_number: int """ raise NotImplementedError('Only redis-consumer mode is supported.') - def get_change_number(self): + async def get_change_number(self): """ - Retrieve latest split change number. + Retrieve latest rule based segment change number. :rtype: int """ try: - stored_value = self._redis.get(self._SPLIT_TILL_KEY) + stored_value = await self._redis.get(self._RB_SEGMENT_TILL_KEY) + _LOGGER.debug("Fetching rule based segment Change Number from redis: %s" % stored_value) return json.loads(stored_value) if stored_value is not None else None + except RedisAdapterException: - _LOGGER.error('Error fetching split change number from storage') + _LOGGER.error('Error fetching rule based segment change number from storage') _LOGGER.debug('Error: ', exc_info=True) return None - - def set_change_number(self, new_change_number): + + async def contains(self, segment_names): """ - Set the latest change number. + Return whether the segments exists in rule based segment in cache. - :param new_change_number: New change number. - :type new_change_number: int - """ - raise NotImplementedError('Only redis-consumer mode is supported.') + :param segment_names: segment name to validate. + :type segment_names: str - def get_split_names(self): + :return: True if segment names exists. False otherwise. + :rtype: bool """ - Retrieve a list of all split names. + return set(segment_names).issubset(await self.get_segment_names()) + + async def get_segment_names(self): + """ + Retrieve a list of all rule based segments names. - :return: List of split names. + :return: List of segment names. :rtype: list(str) """ try: - keys = self._redis.keys(self._get_key('*')) + keys = await self._redis.keys(self._get_key('*')) + _LOGGER.debug("Fetchting rule based segments names from redis: %s" % keys) return [key.replace(self._get_key(''), '') for key in keys] + except RedisAdapterException: - _LOGGER.error('Error fetching split names from storage') + _LOGGER.error('Error fetching rule based segments names from storage') _LOGGER.debug('Error: ', exc_info=True) return [] - def get_all_splits(self): + async def get_large_segment_names(self): """ - Return all the splits in cache. + Retrieve a list of all excluded large segments names. - :return: List of all splits in cache. - :rtype: list(splitio.models.splits.Split) + :return: List of segment names. + :rtype: list(str) """ - keys = self._redis.keys(self._get_key('*')) - to_return = [] + pass + + async def fetch_many(self, segment_names): + """ + Retrieve rule based segment. + + :param segment_names: Names of the rule based segments to fetch. + :type segment_names: list(str) + + :return: A dict with rule based segment objects parsed from redis. + :rtype: dict(segment_name, splitio.models.rule_based_segment.RuleBasedSegment) + """ + to_return = dict() + if len(segment_names) == 0: + return to_return + try: - raw_splits = self._redis.mget(keys) - for raw in raw_splits: + keys = [self._get_key(segment_name) for segment_name in segment_names] + raw_rbs_segments = await self._redis.mget(keys) + _LOGGER.debug("Fetchting rule based segment [%s] from redis" % segment_names) + _LOGGER.debug(raw_rbs_segments) + for i in range(len(raw_rbs_segments)): + rbs_segment = None try: - to_return.append(splits.from_raw(json.loads(raw))) + rbs_segment = rule_based_segments.from_raw(json.loads(raw_rbs_segments[i])) except (ValueError, TypeError): - _LOGGER.error('Could not parse split. Skipping') - _LOGGER.debug("Raw split that failed parsing attempt: %s", raw) + _LOGGER.error('Could not parse rule based segment.') + _LOGGER.debug("Raw rule based segment that failed parsing attempt: %s", raw_rbs_segments[i]) + to_return[segment_names[i]] = rbs_segment except RedisAdapterException: - _LOGGER.error('Error fetching all splits from storage') + _LOGGER.error('Error fetching rule based segments from storage') _LOGGER.debug('Error: ', exc_info=True) return to_return - def kill_locally(self, split_name, default_treatment, change_number): - """ - Local kill for split - - :param split_name: name of the split to perform kill - :type split_name: str - :param default_treatment: name of the default treatment to return - :type default_treatment: str - :param change_number: change_number - :type change_number: int - """ - raise NotImplementedError('Not supported for redis.') - - -class RedisSegmentStorage(SegmentStorage): - """Redis based segment storage class.""" +class RedisSplitStorageBase(SplitStorage): + """Redis-based storage base for feature flags.""" - _SEGMENTS_KEY = 'SPLITIO.segment.{segment_name}' - _SEGMENTS_TILL_KEY = 'SPLITIO.segment.{segment_name}.till' + _FEATURE_FLAG_KEY = 'SPLITIO.split.{feature_flag_name}' + _FEATURE_FLAG_TILL_KEY = 'SPLITIO.splits.till' + _TRAFFIC_TYPE_KEY = 'SPLITIO.trafficType.{traffic_type_name}' + _FLAG_SET_KEY = 'SPLITIO.flagSet.{flag_set}' - def __init__(self, redis_client): + def _get_key(self, feature_flag_name): """ - Class constructor. + Use the provided feature_flag_name to build the appropriate redis key. - :param redis_client: Redis client or compliant interface. - :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param feature_flag_name: Name of the feature flag to interact with in redis. + :type feature_flag_name: str + + :return: Redis key. + :rtype: str. """ - self._redis = redis_client + return self._FEATURE_FLAG_KEY.format(feature_flag_name=feature_flag_name) - def _get_till_key(self, segment_name): + def _get_traffic_type_key(self, traffic_type_name): """ - Use the provided segment_name to build the appropriate redis key. + Use the provided traffic type name to build the appropriate redis key. - :param segment_name: Name of the segment to interact with in redis. - :type segment_name: str + :param traffic_type: Name of the traffic type to interact with in redis. + :type traffic_type_name: str :return: Redis key. :rtype: str. """ - return self._SEGMENTS_TILL_KEY.format(segment_name=segment_name) + return self._TRAFFIC_TYPE_KEY.format(traffic_type_name=traffic_type_name) - def _get_key(self, segment_name): + def _get_flag_set_key(self, flag_set): """ - Use the provided segment_name to build the appropriate redis key. - - :param segment_name: Name of the segment to interact with in redis. - :type segment_name: str - + Use the provided flag set to build the appropriate redis key. + :param flag_set: Name of the flag set to interact with in redis. + :type flag_set: str :return: Redis key. :rtype: str. """ - return self._SEGMENTS_KEY.format(segment_name=segment_name) + return self._FLAG_SET_KEY.format(flag_set=flag_set) - def get(self, segment_name): + def get(self, feature_flag_name): # pylint: disable=method-hidden """ - Retrieve a segment. + Retrieve a feature flag. - :param segment_name: Name of the segment to fetch. - :type segment_name: str + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str - :return: Segment object is key exists. None otherwise. - :rtype: splitio.models.segments.Segment + :return: A feature flag object parsed from redis if the key exists. None otherwise + :rtype: splitio.models.splits.Split """ - try: - keys = (self._redis.smembers(self._get_key(segment_name))) - till = self.get_change_number(segment_name) - if not keys or till is None: - return None - return segments.Segment(segment_name, keys, till) - except RedisAdapterException: - _LOGGER.error('Error fetching segment from storage') - _LOGGER.debug('Error: ', exc_info=True) - return None + pass - def update(self, segment_name, to_add, to_remove, change_number=None): + def fetch_many(self, feature_flag_names): """ - Store a split. + Retrieve feature flags. - :param segment_name: Name of the segment to update. - :type segment_name: str - :param to_add: List of members to add to the segment. - :type to_add: list - :param to_remove: List of members to remove from the segment. - :type to_remove: list + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature flag objects parsed from redis. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) """ - raise NotImplementedError('Only redis-consumer mode is supported.') + pass - def get_change_number(self, segment_name): + def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hidden """ - Retrieve latest change number for a segment. + Return whether the traffic type exists in at least one feature flag in cache. - :param segment_name: Name of the segment. - :type segment_name: str + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str - :rtype: int + :return: True if the traffic type is valid. False otherwise. + :rtype: bool """ - try: - stored_value = self._redis.get(self._get_till_key(segment_name)) - return json.loads(stored_value) if stored_value is not None else None - except RedisAdapterException: - _LOGGER.error('Error fetching segment change number from storage') - _LOGGER.debug('Error: ', exc_info=True) - return None + pass - def set_change_number(self, segment_name, new_change_number): + def update(self, to_add, to_delete, new_change_number): """ - Set the latest change number. + Update feature flag storage. - :param segment_name: Name of the segment. - :type segment_name: str + :param to_add: List of feature flags to add + :type to_add: list[splitio.models.splits.Split] + :param to_delete: List of feature flags to delete + :type to_delete: list[splitio.models.splits.Split] :param new_change_number: New change number. :type new_change_number: int """ raise NotImplementedError('Only redis-consumer mode is supported.') - def put(self, segment): + def get_change_number(self): """ - Store a segment. + Retrieve latest feature flag change number. - :param segment: Segment to store. - :type segment: splitio.models.segment.Segment + :rtype: int """ - raise NotImplementedError('Only redis-consumer mode is supported.') + pass - def segment_contains(self, segment_name, key): + def get_split_names(self): """ - Check whether a specific key belongs to a segment in storage. + Retrieve a list of all feature flag names. - :param segment_name: Name of the segment to search in. - :type segment_name: str - :param key: Key to search for. - :type key: str + :return: List of feature flag names. + :rtype: list(str) + """ + pass - :return: True if the segment contains the key. False otherwise. - :rtype: bool + def get_splits_count(self): """ - try: - return self._redis.sismember(self._get_key(segment_name), key) - except RedisAdapterException: - _LOGGER.error('Error testing members in segment stored in redis') - _LOGGER.debug('Error: ', exc_info=True) - return None + Return feature flags count. + + :rtype: int + """ + return 0 + def get_all_splits(self): + """ + Return all the feature flags in cache. + :return: List of all feature flags in cache. + :rtype: list(splitio.models.splits.Split) + """ + pass -class RedisImpressionsStorage(ImpressionStorage, ImpressionPipelinedStorage): - """Redis based event storage class.""" + def kill_locally(self, feature_flag_name, default_treatment, change_number): + """ + Local kill for feature flag - IMPRESSIONS_QUEUE_KEY = 'SPLITIO.impressions' - IMPRESSIONS_KEY_DEFAULT_TTL = 3600 + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + raise NotImplementedError('Not supported for redis.') - def __init__(self, redis_client, sdk_metadata): + +class RedisSplitStorage(RedisSplitStorageBase): + """Redis-based storage for feature flags.""" + + def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE, config_flag_sets=[]): """ Class constructor. :param redis_client: Redis client or compliant interface. :type redis_client: splitio.storage.adapters.redis.RedisAdapter - :param sdk_metadata: SDK & Machine information. - :type sdk_metadata: splitio.client.util.SdkMetadata """ self._redis = redis_client - self._sdk_metadata = sdk_metadata + self.flag_set_filter = FlagSetsFilter(config_flag_sets) + self._pipe = self._redis.pipeline + if enable_caching: + self.get = add_cache(lambda *p, **_: p[0], max_age)(self.get) + self.is_valid_traffic_type = add_cache(lambda *p, **_: p[0], max_age)(self.is_valid_traffic_type) # pylint: disable=line-too-long + self.fetch_many = add_cache(lambda *p, **_: frozenset(p[0]), max_age)(self.fetch_many) - def _wrap_impressions(self, impressions): + def get(self, feature_flag_name): # pylint: disable=method-hidden """ - Wrap impressions to be stored in redis + Retrieve a feature flag. - :param impressions: Impression to add to the queue. - :type impressions: splitio.models.impressions.Impression + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str - :return: Processed impressions. - :rtype: list[splitio.models.impressions.Impression] + :return: A feature flag object parsed from redis if the key exists. None otherwise + :rtype: splitio.models.splits.Split + """ + try: + raw = self._redis.get(self._get_key(feature_flag_name)) + _LOGGER.debug("Fetchting feature flag [%s] from redis" % feature_flag_name) + _LOGGER.debug(raw) + return splits.from_raw(json.loads(raw)) if raw is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_feature_flags_by_sets(self, flag_sets): + """ + Retrieve feature flags by flag set. + :param flag_set: Names of the flag set to fetch. + :type flag_set: str + :return: Feature flag names that are tagged with the flag set + :rtype: listt(str) + """ + try: + sets_to_fetch = get_valid_flag_sets(flag_sets, self.flag_set_filter) + if sets_to_fetch == []: + return [] + + keys = [self._get_flag_set_key(flag_set) for flag_set in sets_to_fetch] + pipe = self._pipe() + for key in keys: + pipe.smembers(key) + result_sets = pipe.execute() + _LOGGER.debug("Fetchting Feature flags by set [%s] from redis" % (keys)) + _LOGGER.debug(result_sets) + return list(combine_valid_flag_sets(result_sets)) + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + + :return: A dict with feature flag objects parsed from redis. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) + """ + to_return = dict() + try: + keys = [self._get_key(feature_flag_name) for feature_flag_name in feature_flag_names] + raw_feature_flags = self._redis.mget(keys) + _LOGGER.debug("Fetchting feature flags [%s] from redis" % feature_flag_names) + _LOGGER.debug(raw_feature_flags) + for i in range(len(feature_flag_names)): + feature_flag = None + try: + feature_flag = splits.from_raw(json.loads(raw_feature_flags[i])) + except (ValueError, TypeError): + _LOGGER.error('Could not parse feature flag.') + _LOGGER.debug("Raw feature flag that failed parsing attempt: %s", raw_feature_flags[i]) + to_return[feature_flag_names[i]] = feature_flag + except RedisAdapterException: + _LOGGER.error('Error fetching feature flags from storage') + _LOGGER.debug('Error: ', exc_info=True) + return to_return + + def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hidden + """ + Return whether the traffic type exists in at least one feature flag in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + raw = self._redis.get(self._get_traffic_type_key(traffic_type_name)) + count = json.loads(raw) if raw else 0 + _LOGGER.debug("Fetching TrafficType [%s] count in redis: %s" % (traffic_type_name, count)) + return count > 0 + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return False + + def get_change_number(self): + """ + Retrieve latest feature flag change number. + + :rtype: int + """ + try: + stored_value = self._redis.get(self._FEATURE_FLAG_TILL_KEY) + _LOGGER.debug("Fetching feature flag Change Number from redis: %s" % stored_value) + return json.loads(stored_value) if stored_value is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag change number from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_split_names(self): + """ + Retrieve a list of all feature flag names. + + :return: List of feature flag names. + :rtype: list(str) + """ + try: + keys = self._redis.keys(self._get_key('*')) + _LOGGER.debug("Fetchting feature flag names from redis: %s" % keys) + return [key.replace(self._get_key(''), '') for key in keys] + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return [] + + def get_all_splits(self): + """ + Return all the feature flags in cache. + :return: List of all feature flags in cache. + :rtype: list(splitio.models.splits.Split) + """ + keys = self._redis.keys(self._get_key('*')) + to_return = [] + try: + _LOGGER.debug("Fetchting all feature flags from redis: %s" % keys) + raw_feature_flags = self._redis.mget(keys) + _LOGGER.debug(raw_feature_flags) + for raw in raw_feature_flags: + try: + to_return.append(splits.from_raw(json.loads(raw))) + except (ValueError, TypeError): + _LOGGER.error('Could not parse feature flag. Skipping') + _LOGGER.debug("Raw feature flag that failed parsing attempt: %s", raw) + except RedisAdapterException: + _LOGGER.error('Error fetching all feature flags from storage') + _LOGGER.debug('Error: ', exc_info=True) + return to_return + +class RedisSplitStorageAsync(RedisSplitStorage): + """Async Redis-based storage for feature flags.""" + + def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE, config_flag_sets=[]): + """ + Class constructor. + """ + self.redis = redis_client + self._enable_caching = enable_caching + self.flag_set_filter = FlagSetsFilter(config_flag_sets) + self._pipe = self.redis.pipeline + if enable_caching: + self._feature_flag_cache = LocalMemoryCacheAsync(None, None, max_age) + self._traffic_type_cache = LocalMemoryCacheAsync(None, None, max_age) + + + async def get(self, feature_flag_name): # pylint: disable=method-hidden + """ + Retrieve a feature flag. + :param feature_flag_name: Name of the feature to fetch. + :type feature_flag_name: str + + :param default_treatment: name of the default treatment to return + :type default_treatment: str + return: A feature flag object parsed from redis if the key exists. None otherwise + + :param change_number: change_number + :rtype: splitio.models.splits.Split + :type change_number: int + """ + try: + raw_feature_flags = None + if self._enable_caching: + raw_feature_flags = await self._feature_flag_cache.get_key(feature_flag_name) + if raw_feature_flags is None: + raw_feature_flags = await self.redis.get(self._get_key(feature_flag_name)) + if self._enable_caching: + await self._feature_flag_cache.add_key(feature_flag_name, raw_feature_flags) + _LOGGER.debug("Fetchting feature flag [%s] from redis" % feature_flag_name) + _LOGGER.debug(raw_feature_flags) + return splits.from_raw(json.loads(raw_feature_flags)) if raw_feature_flags is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_feature_flags_by_sets(self, flag_sets): + """ + Retrieve feature flags by flag set. + :param flag_set: Names of the flag set to fetch. + :type flag_set: str + :return: Feature flag names that are tagged with the flag set + :rtype: listt(str) + """ + try: + sets_to_fetch = get_valid_flag_sets(flag_sets, self.flag_set_filter) + if sets_to_fetch == []: + return [] + + keys = [self._get_flag_set_key(flag_set) for flag_set in sets_to_fetch] + pipe = self._pipe() + [pipe.smembers(key) for key in keys] + result_sets = await pipe.execute() + _LOGGER.debug("Fetchting Feature flags by set [%s] from redis" % (keys)) + _LOGGER.debug(result_sets) + return list(combine_valid_flag_sets(result_sets)) + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def fetch_many(self, feature_flag_names): + """ + Retrieve feature flags. + :param feature_flag_names: Names of the features to fetch. + :type feature_flag_name: list(str) + :return: A dict with feature flag objects parsed from redis. + :rtype: dict(feature_flag_name, splitio.models.splits.Split) + """ + to_return = dict() + try: + raw_feature_flags = None + if self._enable_caching: + raw_feature_flags = await self._feature_flag_cache.get_key(frozenset(feature_flag_names)) + if raw_feature_flags is None: + raw_feature_flags = await self.redis.mget([self._get_key(feature_flag_name) for feature_flag_name in feature_flag_names]) + if self._enable_caching: + await self._feature_flag_cache.add_key(frozenset(feature_flag_names), raw_feature_flags) + for i in range(len(feature_flag_names)): + feature_flag = None + try: + feature_flag = splits.from_raw(json.loads(raw_feature_flags[i])) + except (ValueError, TypeError): + _LOGGER.error('Could not parse feature flag.') + _LOGGER.debug("Raw feature flag that failed parsing attempt: %s", raw_feature_flags[i]) + to_return[feature_flag_names[i]] = feature_flag + except RedisAdapterException: + _LOGGER.error('Error fetching feature flags from storage') + _LOGGER.debug('Error: ', exc_info=True) + return to_return + + async def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hidden + """ + Return whether the traffic type exists in at least one feature flag in cache. + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + raw_traffic_type = None + if self._enable_caching: + raw_traffic_type = await self._traffic_type_cache.get_key(traffic_type_name) + if raw_traffic_type is None: + raw_traffic_type = await self.redis.get(self._get_traffic_type_key(traffic_type_name)) + if self._enable_caching: + await self._traffic_type_cache.add_key(traffic_type_name, raw_traffic_type) + count = json.loads(raw_traffic_type) if raw_traffic_type else 0 + return count > 0 + + except RedisAdapterException: + _LOGGER.error('Error fetching traffic type from storage') + _LOGGER.debug('Error: ', exc_info=True) + return False + + async def get_change_number(self): + """ + Retrieve latest feature flag change number. + :rtype: int + """ + try: + stored_value = await self.redis.get(self._FEATURE_FLAG_TILL_KEY) + return json.loads(stored_value) if stored_value is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag change number from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_split_names(self): + """ + Retrieve a list of all feature flag names. + :return: List of feature flag names. + :rtype: list(str) + """ + try: + keys = await self.redis.keys(self._get_key('*')) + return [key.replace(self._get_key(''), '') for key in keys] + + except RedisAdapterException: + _LOGGER.error('Error fetching feature flag names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return [] + + async def get_all_splits(self): + """ + Return all the feature flags in cache. + :return: List of all feature flags in cache. + :rtype: list(splitio.models.splits.Split) + """ + keys = await self.redis.keys(self._get_key('*')) + to_return = [] + try: + raw_feature_flags = await self.redis.mget(keys) + for raw in raw_feature_flags: + try: + to_return.append(splits.from_raw(json.loads(raw))) + except (ValueError, TypeError): + _LOGGER.error('Could not parse feature flag. Skipping') + _LOGGER.debug("Raw feature flag that failed parsing attempt: %s", raw) + except RedisAdapterException: + _LOGGER.error('Error fetching all feature flags from storage') + _LOGGER.debug('Error: ', exc_info=True) + return to_return + + +class RedisSegmentStorageBase(SegmentStorage): + """Redis based segment storage base class.""" + + _SEGMENTS_KEY = 'SPLITIO.segment.{segment_name}' + _SEGMENTS_TILL_KEY = 'SPLITIO.segment.{segment_name}.till' + + def _get_till_key(self, segment_name): + """ + Use the provided segment_name to build the appropriate redis key. + + :param segment_name: Name of the segment to interact with in redis. + :type segment_name: str + + :return: Redis key. + :rtype: str. + """ + return self._SEGMENTS_TILL_KEY.format(segment_name=segment_name) + + def _get_key(self, segment_name): + """ + Use the provided segment_name to build the appropriate redis key. + + :param segment_name: Name of the segment to interact with in redis. + :type segment_name: str + + :return: Redis key. + :rtype: str. + """ + return self._SEGMENTS_KEY.format(segment_name=segment_name) + + def get(self, segment_name): + """Retrieve a segment.""" + pass + + def update(self, segment_name, to_add, to_remove, change_number=None): + """ + Store a segment. + + :param segment_name: Name of the segment to update. + :type segment_name: str + :param to_add: List of members to add to the segment. + :type to_add: list + :param to_remove: List of members to remove from the segment. + :type to_remove: list + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def get_change_number(self, segment_name): + """ + Retrieve latest change number for a segment. + + :param segment_name: Name of the segment. + :type segment_name: str + + :rtype: int + """ + pass + + def set_change_number(self, segment_name, new_change_number): + """ + Set the latest change number. + + :param segment_name: Name of the segment. + :type segment_name: str + :param new_change_number: New change number. + :type new_change_number: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def put(self, segment): + """ + Store a segment. + + :param segment: Segment to store. + :type segment: splitio.models.segment.Segment + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def segment_contains(self, segment_name, key): + """ + Check whether a specific key belongs to a segment in storage. + + :param segment_name: Name of the segment to search in. + :type segment_name: str + :param key: Key to search for. + :type key: str + + :return: True if the segment contains the key. False otherwise. + :rtype: bool + """ + try: + res = self._redis.sismember(self._get_key(segment_name), key) + _LOGGER.debug("Checking Segment [%s] contain key [%s] in redis: %s" % (segment_name, key, res)) + return bool(res) + except RedisAdapterException: + _LOGGER.error('Error testing members in segment stored in redis') + _LOGGER.debug('Error: ', exc_info=True) + return False + + def get_segments_count(self): + """ + Return segment count. + + :return: 0 + :rtype: int + """ + return 0 + + def get_segments_keys_count(self): + """ + Return segment count. + + :rtype: int + """ + return 0 + + +class RedisSegmentStorage(RedisSegmentStorageBase): + """Redis based segment storage class.""" + + def __init__(self, redis_client): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + self._redis = redis_client + + def get(self, segment_name): + """ + Retrieve a segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :return: Segment object is key exists. None otherwise. + :rtype: splitio.models.segments.Segment + """ + try: + keys = (self._redis.smembers(self._get_key(segment_name))) + _LOGGER.debug("Fetchting Segment [%s] from redis" % segment_name) + _LOGGER.debug(keys) + till = self.get_change_number(segment_name) + if not keys or till is None: + return None + + return segments.Segment(segment_name, keys, till) + + except RedisAdapterException: + _LOGGER.error('Error fetching segment from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def get_change_number(self, segment_name): + """ + Retrieve latest change number for a segment. + + :param segment_name: Name of the segment. + :type segment_name: str + + :rtype: int + """ + try: + stored_value = self._redis.get(self._get_till_key(segment_name)) + _LOGGER.debug("Fetchting Change Number for Segment [%s] from redis: " % stored_value) + return json.loads(stored_value) if stored_value is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching segment change number from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + def segment_contains(self, segment_name, key): + """ + Check whether a specific key belongs to a segment in storage. + + :param segment_name: Name of the segment to search in. + :type segment_name: str + :param key: Key to search for. + :type key: str + + :return: True if the segment contains the key. False otherwise. + :rtype: bool + """ + try: + res = self._redis.sismember(self._get_key(segment_name), key) + _LOGGER.debug("Checking Segment [%s] contain key [%s] in redis: %s" % (segment_name, key, res)) + return res + + except RedisAdapterException: + _LOGGER.error('Error testing members in segment stored in redis') + _LOGGER.debug('Error: ', exc_info=True) + return None + + +class RedisSegmentStorageAsync(RedisSegmentStorageBase): + """Redis based segment storage async class.""" + + def __init__(self, redis_client): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + self._redis = redis_client + + async def get(self, segment_name): + """ + Retrieve a segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :return: Segment object is key exists. None otherwise. + :rtype: splitio.models.segments.Segment + """ + try: + keys = (await self._redis.smembers(self._get_key(segment_name))) + _LOGGER.debug("Fetchting Segment [%s] from redis" % segment_name) + _LOGGER.debug(keys) + till = await self.get_change_number(segment_name) + if not keys or till is None: + return None + + return segments.Segment(segment_name, keys, till) + + except RedisAdapterException: + _LOGGER.error('Error fetching segment from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_change_number(self, segment_name): + """ + Retrieve latest change number for a segment. + + :param segment_name: Name of the segment. + :type segment_name: str + + :rtype: int + """ + try: + stored_value = await self._redis.get(self._get_till_key(segment_name)) + _LOGGER.debug("Fetchting Change Number for Segment [%s] from redis: " % stored_value) + return json.loads(stored_value) if stored_value is not None else None + + except RedisAdapterException: + _LOGGER.error('Error fetching segment change number from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def segment_contains(self, segment_name, key): + """ + Check whether a specific key belongs to a segment in storage. + + :param segment_name: Name of the segment to search in. + :type segment_name: str + :param key: Key to search for. + :type key: str + + :return: True if the segment contains the key. False otherwise. + :rtype: bool + """ + try: + res = await self._redis.sismember(self._get_key(segment_name), key) + _LOGGER.debug("Checking Segment [%s] contain key [%s] in redis: %s" % (segment_name, key, res)) + return res + + except RedisAdapterException: + _LOGGER.error('Error testing members in segment stored in redis') + _LOGGER.debug('Error: ', exc_info=True) + return None + + +class RedisImpressionsStorageBase(ImpressionStorage, ImpressionPipelinedStorage): + """Redis based event storage base class.""" + + IMPRESSIONS_QUEUE_KEY = 'SPLITIO.impressions' + IMPRESSIONS_KEY_DEFAULT_TTL = 3600 + + def _wrap_impressions(self, impressions): + """ + Wrap impressions to be stored in redis + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Processed impressions. + :rtype: list[splitio.models.impressions.Impression] """ bulk_impressions = [] for impression in impressions: @@ -391,12 +1106,132 @@ def _wrap_impressions(self, impressions): 'r': impression.label, 'c': impression.change_number, 'm': impression.time, + 'properties': impression.properties } } bulk_impressions.append(json.dumps(to_store)) return bulk_impressions - def expire_key(self, total_keys, inserted): + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + def add_impressions_to_pipe(self, impressions, pipe): + """ + Add put operation to pipeline + + :param impressions: List of one or more impressions to store. + :type impressions: list + :param pipe: Redis pipe. + :type pipe: redis.pipe + """ + bulk_impressions = self._wrap_impressions(impressions) + _LOGGER.debug("Adding Impressions to redis key %s" % (self.IMPRESSIONS_QUEUE_KEY)) + _LOGGER.debug(bulk_impressions) + pipe.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) + + def put(self, impressions): + """ + Add an impression to the redis storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + pass + + def pop_many(self, count): + """ + Pop the oldest N events from storage. + + :param count: Number of events to pop. + :type count: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def clear(self): + """ + Clear data. + """ + raise NotImplementedError('Not supported for redis.') + + +class RedisImpressionsStorage(RedisImpressionsStorageBase): + """Redis based event storage class.""" + + def __init__(self, redis_client, sdk_metadata): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._redis = redis_client + self._sdk_metadata = sdk_metadata + + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) + + def put(self, impressions): + """ + Add an impression to the redis storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + bulk_impressions = self._wrap_impressions(impressions) + try: + _LOGGER.debug("Adding Impressions to redis key %s" % (self.IMPRESSIONS_QUEUE_KEY)) + _LOGGER.debug(bulk_impressions) + inserted = self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) + self.expire_key(inserted, len(bulk_impressions)) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add impression to redis') + _LOGGER.error('Error: ', exc_info=True) + return False + + +class RedisImpressionsStorageAsync(RedisImpressionsStorageBase): + """Redis based event storage async class.""" + + def __init__(self, redis_client, sdk_metadata): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._redis = redis_client + self._sdk_metadata = sdk_metadata + + async def expire_key(self, total_keys, inserted): """ Set expire @@ -406,22 +1241,9 @@ def expire_key(self, total_keys, inserted): :type inserted: int """ if total_keys == inserted: - _LOGGER.debug("SET EXPIRE KEY FOR QUEUE") - self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) - - def add_impressions_to_pipe(self, impressions, pipe): - """ - Add put operation to pipeline - - :param impressions: List of one or more impressions to store. - :type impressions: list - :param pipe: Redis pipe. - :type pipe: redis.pipe - """ - bulk_impressions = self._wrap_impressions(impressions) - pipe.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) + await self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) - def put(self, impressions): + async def put(self, impressions): """ Add an impression to the redis storage. @@ -433,14 +1255,70 @@ def put(self, impressions): """ bulk_impressions = self._wrap_impressions(impressions) try: - inserted = self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) - self.expire_key(inserted, len(bulk_impressions)) + _LOGGER.debug("Adding Impressions to redis key %s" % (self.IMPRESSIONS_QUEUE_KEY)) + _LOGGER.debug(bulk_impressions) + inserted = await self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) + await self.expire_key(inserted, len(bulk_impressions)) return True + except RedisAdapterException: _LOGGER.error('Something went wrong when trying to add impression to redis') _LOGGER.error('Error: ', exc_info=True) return False + +class RedisEventsStorageBase(EventStorage): + """Redis based event storage base class.""" + + _EVENTS_KEY_TEMPLATE = 'SPLITIO.events' + _EVENTS_KEY_DEFAULT_TTL = 3600 + + def add_events_to_pipe(self, events, pipe): + """ + Add put operation to pipeline + + :param impressions: List of one or more impressions to store. + :type impressions: list + :param pipe: Redis pipe. + :type pipe: redis.pipe + """ + bulk_events = self._wrap_events(events) + _LOGGER.debug("Adding Events to redis key %s" % (self._EVENTS_KEY_TEMPLATE)) + _LOGGER.debug(bulk_events) + pipe.rpush(self._EVENTS_KEY_TEMPLATE, *bulk_events) + + def _wrap_events(self, events): + return [ + json.dumps({ + 'e': { + 'key': e.event.key, + 'trafficTypeName': e.event.traffic_type_name, + 'eventTypeId': e.event.event_type_id, + 'value': e.event.value, + 'timestamp': e.event.timestamp, + 'properties': e.event.properties, + }, + 'm': { + 's': self._sdk_metadata.sdk_version, + 'n': self._sdk_metadata.instance_name, + 'i': self._sdk_metadata.instance_ip, + } + }) + for e in events + ] + + def put(self, events): + """ + Add an event to the redis storage. + + :param event: Event to add to the queue. + :type event: splitio.models.events.Event + + :return: Whether the event has been added or not. + :rtype: bool + """ + pass + def pop_many(self, count): """ Pop the oldest N events from storage. @@ -456,11 +1334,19 @@ def clear(self): """ raise NotImplementedError('Not supported for redis.') + def expire_keys(self, total_keys, inserted): + """ + Set expire -class RedisEventsStorage(EventStorage): - """Redis based event storage class.""" + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass - _KEY_TEMPLATE = 'SPLITIO.events' +class RedisEventsStorage(RedisEventsStorageBase): + """Redis based event storage class.""" def __init__(self, redis_client, sdk_metadata): """ @@ -484,44 +1370,404 @@ def put(self, events): :return: Whether the event has been added or not. :rtype: bool """ - key = self._KEY_TEMPLATE - to_store = [ - json.dumps({ - 'e': { - 'key': e.event.key, - 'trafficTypeName': e.event.traffic_type_name, - 'eventTypeId': e.event.event_type_id, - 'value': e.event.value, - 'timestamp': e.event.timestamp, - 'properties': e.event.properties, - }, - 'm': { - 's': self._sdk_metadata.sdk_version, - 'n': self._sdk_metadata.instance_name, - 'i': self._sdk_metadata.instance_ip, - } - }) - for e in events - ] + key = self._EVENTS_KEY_TEMPLATE + to_store = self._wrap_events(events) try: + _LOGGER.debug("Adding Events to redis key %s" % (key)) + _LOGGER.debug(to_store) self._redis.rpush(key, *to_store) return True + except RedisAdapterException: _LOGGER.error('Something went wrong when trying to add event to redis') _LOGGER.debug('Error: ', exc_info=True) return False - def pop_many(self, count): + def expire_keys(self, total_keys, inserted): """ - Pop the oldest N events from storage. + Set expire - :param count: Number of events to pop. - :type count: int + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int """ - raise NotImplementedError('Only redis-consumer mode is supported.') + if total_keys == inserted: + self._redis.expire(self._EVENTS_KEY_TEMPLATE, self._EVENTS_KEY_DEFAULT_TTL) - def clear(self): + +class RedisEventsStorageAsync(RedisEventsStorageBase): + """Redis based event async storage class.""" + + def __init__(self, redis_client, sdk_metadata): """ - Clear data. + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata """ - raise NotImplementedError('Not supported for redis.') + self._redis = redis_client + self._sdk_metadata = sdk_metadata + + async def put(self, events): + """ + Add an event to the redis storage. + + :param event: Event to add to the queue. + :type event: splitio.models.events.Event + + :return: Whether the event has been added or not. + :rtype: bool + """ + key = self._EVENTS_KEY_TEMPLATE + to_store = self._wrap_events(events) + try: + _LOGGER.debug("Adding Events to redis key %s" % (key)) + _LOGGER.debug(to_store) + await self._redis.rpush(key, *to_store) + return True + + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add event to redis') + _LOGGER.debug('Error: ', exc_info=True) + return False + + async def expire_keys(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._redis.expire(self._EVENTS_KEY_TEMPLATE, self._EVENTS_KEY_DEFAULT_TTL) + + +class RedisTelemetryStorageBase(TelemetryStorage): + """Redis based telemetry storage class.""" + + _TELEMETRY_CONFIG_KEY = 'SPLITIO.telemetry.init' + _TELEMETRY_LATENCIES_KEY = 'SPLITIO.telemetry.latencies' + _TELEMETRY_EXCEPTIONS_KEY = 'SPLITIO.telemetry.exceptions' + _TELEMETRY_KEY_DEFAULT_TTL = 3600 + + def _reset_config_tags(self): + """Reset all config tags""" + pass + + def add_config_tag(self, tag): + """Record tag string.""" + pass + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + initilize telemetry objects + + :param congif: factory configuration parameters + :type config: splitio.client.config + """ + pass + + def pop_config_tags(self): + """Get and reset tags.""" + pass + + def push_config_stats(self): + """push config stats to redis.""" + pass + + def _format_config_stats(self, config_stats, tags): + """format only selected config stats to json""" + return json.dumps({ + 'aF': config_stats['aF'], + 'rF': config_stats['rF'], + 'sT': config_stats['sT'], + 'oM': config_stats['oM'], + 't': tags + }) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + pass + + def add_latency_to_pipe(self, method, bucket, pipe): + """ + record latency data + + :param method: method name + :type method: string + :param latency: latency + :type latency: int64 + :param pipe: Redis pipe. + :type pipe: redis.pipe + """ + _LOGGER.debug("Adding Latency stats to redis key %s" % (self._TELEMETRY_LATENCIES_KEY)) + _LOGGER.debug(self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value + '/' + str(bucket)) + pipe.hincrby(self._TELEMETRY_LATENCIES_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value + '/' + str(bucket), 1) + + def record_latency(self, method, latency): + """ + Not implemented + """ + raise NotImplementedError('Only redis pipe is used.') + + def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + pass + + def record_not_ready_usage(self): + """ + record not ready time + + """ + pass + + def record_bur_time_out(self): + """ + record BUR timeouts + + """ + pass + + def record_impression_stats(self, data_type, count): + pass + + def expire_latency_keys(self, total_keys, inserted): + pass + + def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + +class RedisTelemetryStorage(RedisTelemetryStorageBase): + """Redis based telemetry storage class.""" + + def __init__(self, redis_client, sdk_metadata): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._lock = threading.RLock() + self._reset_config_tags() + self._redis_client = redis_client + self._sdk_metadata = sdk_metadata + self._tel_config = TelemetryConfig() + self._make_pipe = redis_client.pipeline + + def _reset_config_tags(self): + """Reset all config tags""" + with self._lock: + self._config_tags = [] + + def add_config_tag(self, tag): + """Record tag string.""" + with self._lock: + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + initilize telemetry objects + + :param congif: factory configuration parameters + :type config: splitio.client.config + """ + self._tel_config.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + + def pop_config_tags(self): + """Get and reset tags.""" + with self._lock: + tags = self._config_tags + self._reset_config_tags() + return tags + + def push_config_stats(self): + """push config stats to redis.""" + _LOGGER.debug("Adding Config stats to redis key %s" % (self._TELEMETRY_CONFIG_KEY)) + _LOGGER.debug(str(self._format_config_stats(self._tel_config.get_stats(), self.pop_config_tags()))) + self._redis_client.hset(self._TELEMETRY_CONFIG_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip, str(self._format_config_stats(self._tel_config.get_stats(), self.pop_config_tags()))) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + _LOGGER.debug("Adding Excepction stats to redis key %s" % (self._TELEMETRY_EXCEPTIONS_KEY)) + _LOGGER.debug(self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value) + pipe = self._make_pipe() + pipe.hincrby(self._TELEMETRY_EXCEPTIONS_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value, 1) + result = pipe.execute() + self.expire_keys(self._TELEMETRY_EXCEPTIONS_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result[0]) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def expire_latency_keys(self, total_keys, inserted): + """ + Expire lstency keys + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + self.expire_keys(self._TELEMETRY_LATENCIES_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, total_keys, inserted) + + def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._redis_client.expire(queue_key, key_default_ttl) + + def record_bur_time_out(self): + """record BUR timeouts""" + pass + + def record_ready_time(self, ready_time): + """Record ready time.""" + pass + + +class RedisTelemetryStorageAsync(RedisTelemetryStorageBase): + """Redis based telemetry async storage class.""" + + @classmethod + async def create(cls, redis_client, sdk_metadata): + """ + Create instance and reset tags + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + + :return: self instance. + :rtype: splitio.storage.redis.RedisTelemetryStorageAsync + """ + self = cls() + await self._reset_config_tags() + self._redis_client = redis_client + self._sdk_metadata = sdk_metadata + self._tel_config = await TelemetryConfigAsync.create() + self._make_pipe = redis_client.pipeline + return self + + async def _reset_config_tags(self): + """Reset all config tags""" + self._config_tags = [] + + async def add_config_tag(self, tag): + """Record tag string.""" + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + async def record_config(self, config, extra_config, total_flag_sets, invalid_flag_sets): + """ + initilize telemetry objects + + :param congif: factory configuration parameters + :type config: splitio.client.config + """ + await self._tel_config.record_config(config, extra_config, total_flag_sets, invalid_flag_sets) + + async def record_bur_time_out(self): + """record BUR timeouts""" + pass + + async def record_ready_time(self, ready_time): + """Record ready time.""" + pass + + async def pop_config_tags(self): + """Get and reset tags.""" + tags = self._config_tags + await self._reset_config_tags() + return tags + + async def push_config_stats(self): + """push config stats to redis.""" + _LOGGER.debug("Adding Config stats to redis key %s" % (self._TELEMETRY_CONFIG_KEY)) + stats = str(self._format_config_stats(await self._tel_config.get_stats(), await self.pop_config_tags())) + _LOGGER.debug(stats) + await self._redis_client.hset(self._TELEMETRY_CONFIG_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip, stats) + + async def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + _LOGGER.debug("Adding Excepction stats to redis key %s" % (self._TELEMETRY_EXCEPTIONS_KEY)) + _LOGGER.debug(self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value) + pipe = self._make_pipe() + pipe.hincrby(self._TELEMETRY_EXCEPTIONS_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value, 1) + result = await pipe.execute() + await self.expire_keys(self._TELEMETRY_EXCEPTIONS_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result[0]) + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + await self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + async def expire_latency_keys(self, total_keys, inserted): + """ + Expire lstency keys + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + await self.expire_keys(self._TELEMETRY_LATENCIES_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, total_keys, inserted) + + async def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._redis_client.expire(queue_key, key_default_ttl) diff --git a/splitio/sync/event.py b/splitio/sync/event.py index 06c944b0..ff761670 100644 --- a/splitio/sync/event.py +++ b/splitio/sync/event.py @@ -2,12 +2,13 @@ import queue from splitio.api import APIException - +from splitio.optional.loaders import asyncio _LOGGER = logging.getLogger(__name__) class EventSynchronizer(object): + """Event Synchronizer class""" def __init__(self, events_api, storage, bulk_size): """ Class constructor. @@ -65,3 +66,64 @@ def synchronize_events(self): _LOGGER.error('Exception raised while reporting events') _LOGGER.debug('Exception information: ', exc_info=True) self._add_to_failed_queue(to_send) + + +class EventSynchronizerAsync(object): + """Event Synchronizer async class""" + def __init__(self, events_api, storage, bulk_size): + """ + Class constructor. + + :param events_api: Events Api object to send data to the backend + :type events_api: splitio.api.events.EventsAPI + :param storage: Events Storage + :type storage: splitio.storage.EventStorage + :param bulk_size: How many events to send per push. + :type bulk_size: int + + """ + self._api = events_api + self._event_storage = storage + self._bulk_size = bulk_size + self._failed = asyncio.Queue() + + async def _get_failed(self): + """Return up to events stored in the failed eventes queue.""" + events = [] + count = 0 + while count < self._bulk_size and self._failed.qsize() > 0: + try: + events.append(await self._failed.get()) + count += 1 + except asyncio.QueueEmpty: + # If no more items in queue, break the loop + break + return events + + async def _add_to_failed_queue(self, events): + """ + Add events that were about to be sent to a secondary queue for failed sends. + + :param events: List of events that failed to be pushed. + :type events: list + """ + for event in events: + await self._failed.put(event) + + async def synchronize_events(self): + """Send events from both the failed and new queues.""" + to_send = await self._get_failed() + if len(to_send) < self._bulk_size: + # If the amount of previously failed items is less than the bulk + # size, try to complete with new events from storage + to_send.extend(await self._event_storage.pop_many(self._bulk_size - len(to_send))) + + if not to_send: + return + + try: + await self._api.flush_events(to_send) + except APIException: + _LOGGER.error('Exception raised while reporting events') + _LOGGER.debug('Exception information: ', exc_info=True) + await self._add_to_failed_queue(to_send) diff --git a/splitio/sync/impression.py b/splitio/sync/impression.py index 51505d1c..8fd54051 100644 --- a/splitio/sync/impression.py +++ b/splitio/sync/impression.py @@ -2,12 +2,13 @@ import queue from splitio.api import APIException - +from splitio.optional.loaders import asyncio _LOGGER = logging.getLogger(__name__) class ImpressionSynchronizer(object): + """Impressions synchronizer class.""" def __init__(self, impressions_api, storage, bulk_size): """ Class constructor. @@ -68,7 +69,7 @@ def synchronize_impressions(self): class ImpressionsCountSynchronizer(object): - def __init__(self, impressions_api, impressions_manager): + def __init__(self, impressions_api, imp_counter): """ Class constructor. @@ -79,11 +80,15 @@ def __init__(self, impressions_api, impressions_manager): """ self._impressions_api = impressions_api - self._impressions_manager = impressions_manager + self._impressions_counter = imp_counter def synchronize_counters(self): """Send impressions from both the failed and new queues.""" - to_send = self._impressions_manager.get_counts() + + if self._impressions_counter == None: + return + + to_send = self._impressions_counter.pop_all() if not to_send: return @@ -92,3 +97,95 @@ def synchronize_counters(self): except APIException: _LOGGER.error('Exception raised while reporting impression counts') _LOGGER.debug('Exception information: ', exc_info=True) + + +class ImpressionSynchronizerAsync(object): + """Impressions async synchronizer class.""" + def __init__(self, impressions_api, storage, bulk_size): + """ + Class constructor. + + :param impressions_api: Impressions Api object to send data to the backend + :type impressions_api: splitio.api.impressions.ImpressionsAPI + :param storage: Impressions Storage + :type storage: splitio.storage.ImpressionsStorage + :param bulk_size: How many impressions to send per push. + :type bulk_size: int + + """ + self._api = impressions_api + self._impression_storage = storage + self._bulk_size = bulk_size + self._failed = asyncio.Queue() + + async def _get_failed(self): + """Return up to impressions stored in the failed impressions queue.""" + imps = [] + count = 0 + while count < self._bulk_size and self._failed.qsize() > 0: + try: + imps.append(await self._failed.get()) + count += 1 + except asyncio.QueueEmpty: + # If no more items in queue, break the loop + break + return imps + + async def _add_to_failed_queue(self, imps): + """ + Add impressions that were about to be sent to a secondary queue for failed sends. + + :param imps: List of impressions that failed to be pushed. + :type imps: list + """ + for impression in imps: + await self._failed.put(impression) + + async def synchronize_impressions(self): + """Send impressions from both the failed and new queues.""" + to_send = await self._get_failed() + if len(to_send) < self._bulk_size: + # If the amount of previously failed items is less than the bulk + # size, try to complete with new impressions from storage + to_send.extend(await self._impression_storage.pop_many(self._bulk_size - len(to_send))) + + if not to_send: + return + + try: + await self._api.flush_impressions(to_send) + except APIException: + _LOGGER.error('Exception raised while reporting impressions') + _LOGGER.debug('Exception information: ', exc_info=True) + await self._add_to_failed_queue(to_send) + + +class ImpressionsCountSynchronizerAsync(object): + def __init__(self, impressions_api, imp_counter): + """ + Class constructor. + + :param impressions_api: Impressions Api object to send data to the backend + :type impressions_api: splitio.api.impressions.ImpressionsAPI + :param impressions_manager: Impressions manager instance + :type impressions_manager: splitio.engine.impressions.Manager + + """ + self._impressions_api = impressions_api + self._impressions_counter = imp_counter + + async def synchronize_counters(self): + """Send impressions from both the failed and new queues.""" + + if self._impressions_counter == None: + return + + to_send = self._impressions_counter.pop_all() + if not to_send: + return + + try: + await self._impressions_api.flush_counters(to_send) + except APIException: + _LOGGER.error('Exception raised while reporting impression counts') + _LOGGER.debug('Exception information: ', exc_info=True) diff --git a/splitio/sync/manager.py b/splitio/sync/manager.py index 700f2dfe..7254a92e 100644 --- a/splitio/sync/manager.py +++ b/splitio/sync/manager.py @@ -3,10 +3,14 @@ import time from threading import Thread from queue import Queue -from splitio.push.manager import PushManager, Status + +from splitio.optional.loaders import asyncio +from splitio.push.manager import PushManager, PushManagerAsync, Status from splitio.api import APIException from splitio.util.backoff import Backoff - +from splitio.util.time import get_current_epoch_time_ms +from splitio.models.telemetry import SSESyncMode, StreamingEventTypes +from splitio.sync.synchronizer import _SYNC_ALL_NO_RETRIES _LOGGER = logging.getLogger(__name__) @@ -16,7 +20,7 @@ class Manager(object): # pylint:disable=too-many-instance-attributes _CENTINEL_EVENT = object() - def __init__(self, ready_flag, synchronizer, auth_api, streaming_enabled, sdk_metadata, sse_url=None, client_key=None): # pylint:disable=too-many-arguments + def __init__(self, ready_flag, synchronizer, auth_api, streaming_enabled, sdk_metadata, telemetry_runtime_producer, sse_url=None, client_key=None): # pylint:disable=too-many-arguments """ Construct Manager. @@ -44,23 +48,23 @@ def __init__(self, ready_flag, synchronizer, auth_api, streaming_enabled, sdk_me self._streaming_enabled = streaming_enabled self._ready_flag = ready_flag self._synchronizer = synchronizer + self._telemetry_runtime_producer = telemetry_runtime_producer if self._streaming_enabled: self._push_status_handler_active = True self._backoff = Backoff() self._queue = Queue() - self._push = PushManager(auth_api, synchronizer, self._queue, sdk_metadata, sse_url, client_key) + self._push = PushManager(auth_api, synchronizer, self._queue, sdk_metadata, telemetry_runtime_producer, sse_url, client_key) self._push_status_handler = Thread(target=self._streaming_feedback_handler, - name='PushStatusHandler') - self._push_status_handler.setDaemon(True) + name='PushStatusHandler', daemon=True) def recreate(self): """Recreate poolers for forked processes.""" self._synchronizer._split_synchronizers._segment_sync.recreate() - def start(self): + def start(self, max_retry_attempts=_SYNC_ALL_NO_RETRIES): """Start the SDK synchronization tasks.""" try: - self._synchronizer.sync_all() + self._synchronizer.sync_all(max_retry_attempts) self._ready_flag.set() self._synchronizer.start_periodic_data_recording() if self._streaming_enabled: @@ -106,11 +110,13 @@ def _streaming_feedback_handler(self): self._push.update_workers_status(True) self._backoff.reset() _LOGGER.info('streaming up and running. disabling periodic fetching.') + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SYNC_MODE_UPDATE, SSESyncMode.STREAMING.value, get_current_epoch_time_ms())) elif status == Status.PUSH_SUBSYSTEM_DOWN: self._push.update_workers_status(False) self._synchronizer.sync_all() self._synchronizer.start_periodic_fetching() _LOGGER.info('streaming temporarily down. starting periodic fetching') + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SYNC_MODE_UPDATE, SSESyncMode.POLLING.value, get_current_epoch_time_ms())) elif status == Status.PUSH_RETRYABLE_ERROR: self._push.update_workers_status(False) self._push.stop(True) @@ -126,4 +132,192 @@ def _streaming_feedback_handler(self): self._synchronizer.sync_all() self._synchronizer.start_periodic_fetching() _LOGGER.info('non-recoverable error in streaming. switching to polling.') + self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SYNC_MODE_UPDATE, SSESyncMode.POLLING.value, get_current_epoch_time_ms())) + return + + +class ManagerAsync(object): # pylint:disable=too-many-instance-attributes + """Manager Class.""" + + _CENTINEL_EVENT = object() + + def __init__(self, synchronizer, auth_api, streaming_enabled, sdk_metadata, telemetry_runtime_producer, sse_url=None, client_key=None): # pylint:disable=too-many-arguments + """ + Construct Manager. + + :param split_synchronizers: synchronizers for performing start/stop logic + :type split_synchronizers: splitio.sync.synchronizer.Synchronizer + + :param auth_api: Authentication api client + :type auth_api: splitio.api.auth.AuthAPI + + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + + :param streaming_enabled: whether to use streaming or not + :type streaming_enabled: bool + + :param sse_url: streaming base url. + :type sse_url: str + + :param client_key: client key. + :type client_key: str + """ + self._streaming_enabled = streaming_enabled + self._synchronizer = synchronizer + self._telemetry_runtime_producer = telemetry_runtime_producer + if self._streaming_enabled: + self._push_status_handler_active = True + self._backoff = Backoff() + self._queue = asyncio.Queue() + self._push = PushManagerAsync(auth_api, synchronizer, self._queue, sdk_metadata, telemetry_runtime_producer, sse_url, client_key) + self._stopped = False + + async def start(self, max_retry_attempts=_SYNC_ALL_NO_RETRIES): + """Start the SDK synchronization tasks.""" + self._stopped = False + try: + await self._synchronizer.sync_all(max_retry_attempts) + if not self._stopped: + self._synchronizer.start_periodic_data_recording() + if self._streaming_enabled: + asyncio.get_running_loop().create_task(self._streaming_feedback_handler()) + self._push.start() + else: + self._synchronizer.start_periodic_fetching() + except (APIException, RuntimeError): + _LOGGER.error('Exception raised starting Split Manager') + _LOGGER.debug('Exception information: ', exc_info=True) + raise + + async def stop(self, blocking): + """ + Stop manager logic. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.info('Stopping manager tasks') + if self._streaming_enabled: + self._push_status_handler_active = False + await self._queue.put(self._CENTINEL_EVENT) + await self._push.stop(blocking) + await self._push.close_sse_http_client() + await self._synchronizer.shutdown(blocking) + self._stopped = True + + async def _streaming_feedback_handler(self): + """ + Handle status updates from the streaming subsystem. + + :param status: current status of the streaming pipeline. + :type status: splitio.push.status_stracker.Status + """ + while self._push_status_handler_active: + status = await self._queue.get() + if status == self._CENTINEL_EVENT: + continue + if status == Status.PUSH_SUBSYSTEM_UP: + await self._synchronizer.stop_periodic_fetching() + await self._synchronizer.sync_all() + await self._push.update_workers_status(True) + self._backoff.reset() + _LOGGER.info('streaming up and running. disabling periodic fetching.') + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SYNC_MODE_UPDATE, SSESyncMode.STREAMING.value, get_current_epoch_time_ms())) + elif status == Status.PUSH_SUBSYSTEM_DOWN: + await self._push.update_workers_status(False) + await self._synchronizer.sync_all() + self._synchronizer.start_periodic_fetching() + _LOGGER.info('streaming temporarily down. starting periodic fetching') + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SYNC_MODE_UPDATE, SSESyncMode.POLLING.value, get_current_epoch_time_ms())) + elif status == Status.PUSH_RETRYABLE_ERROR: + await self._push.update_workers_status(False) + await self._push.stop(True) + await self._synchronizer.sync_all() + self._synchronizer.start_periodic_fetching() + how_long = self._backoff.get() + _LOGGER.info('error in streaming. restarting flow in %d seconds', how_long) + await asyncio.sleep(how_long) + self._push.start() + elif status == Status.PUSH_NONRETRYABLE_ERROR: + await self._push.update_workers_status(False) + await self._push.stop(False) + await self._synchronizer.sync_all() + self._synchronizer.start_periodic_fetching() + _LOGGER.info('non-recoverable error in streaming. switching to polling.') + await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.SYNC_MODE_UPDATE, SSESyncMode.POLLING.value, get_current_epoch_time_ms())) return + + +class RedisManagerBase(object): # pylint:disable=too-many-instance-attributes + """Manager base Class.""" + + def __init__(self, synchronizer): # pylint:disable=too-many-arguments + """ + Construct Manager. + + :param synchronizer: synchronizers for performing start/stop logic + :type synchronizer: splitio.sync.synchronizer.Synchronizer + """ + self._ready_flag = True + self._synchronizer = synchronizer + + def recreate(self): + """Not implemented""" + return + + def start(self): + """Start the SDK synchronization tasks.""" + try: + self._synchronizer.start_periodic_data_recording() + + except (APIException, RuntimeError): + _LOGGER.error('Exception raised starting Split Manager') + _LOGGER.debug('Exception information: ', exc_info=True) + raise + + +class RedisManager(RedisManagerBase): # pylint:disable=too-many-instance-attributes + """Manager Class.""" + + def __init__(self, synchronizer): # pylint:disable=too-many-arguments + """ + Construct Manager. + + :param synchronizer: synchronizers for performing start/stop logic + :type synchronizer: splitio.sync.synchronizer.Synchronizer + """ + RedisManagerBase.__init__(self, synchronizer) + + def stop(self, blocking): + """ + Stop manager logic. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.info('Stopping manager tasks') + self._synchronizer.shutdown(blocking) + + +class RedisManagerAsync(RedisManagerBase): # pylint:disable=too-many-instance-attributes + """Manager async Class.""" + + def __init__(self, synchronizer): # pylint:disable=too-many-arguments + """ + Construct Manager. + + :param synchronizer: synchronizers for performing start/stop logic + :type synchronizer: splitio.sync.synchronizer.Synchronizer + """ + RedisManagerBase.__init__(self, synchronizer) + + async def stop(self, blocking): + """ + Stop manager logic. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.info('Stopping manager tasks') + await self._synchronizer.shutdown(blocking) diff --git a/splitio/sync/segment.py b/splitio/sync/segment.py index 37e453bf..a87759e1 100644 --- a/splitio/sync/segment.py +++ b/splitio/sync/segment.py @@ -1,12 +1,17 @@ import logging import time +import json +import os from splitio.api import APIException from splitio.api.commons import FetchOptions from splitio.tasks.util import workerpool from splitio.models import segments from splitio.util.backoff import Backoff - +from splitio.optional.loaders import asyncio, aiofiles +from splitio.sync import util +from splitio.util.storage_helper import get_standard_segment_names_in_rbs_storage, get_standard_segment_names_in_rbs_storage_async +from splitio.optional.loaders import asyncio _LOGGER = logging.getLogger(__name__) @@ -14,27 +19,29 @@ _ON_DEMAND_FETCH_BACKOFF_BASE = 10 # backoff base starting at 10 seconds _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT = 60 # don't sleep for more than 1 minute _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES = 10 +_MAX_WORKERS = 10 class SegmentSynchronizer(object): - def __init__(self, segment_api, split_storage, segment_storage): + def __init__(self, segment_api, feature_flag_storage, segment_storage, rule_based_segment_storage): """ Class constructor. :param segment_api: API to retrieve segments from backend. :type segment_api: splitio.api.SegmentApi - :param split_storage: Split Storage. - :type split_storage: splitio.storage.InMemorySplitStorage + :param feature_flag_storage: Feature Flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage :param segment_storage: Segment storage reference. :type segment_storage: splitio.storage.SegmentStorage """ self._api = segment_api - self._split_storage = split_storage + self._feature_flag_storage = feature_flag_storage self._segment_storage = segment_storage - self._worker_pool = workerpool.WorkerPool(10, self.synchronize_segment) + self._rule_based_segment_storage = rule_based_segment_storage + self._worker_pool = workerpool.WorkerPool(_MAX_WORKERS, self.synchronize_segment) self._worker_pool.start() self._backoff = Backoff( _ON_DEMAND_FETCH_BACKOFF_BASE, @@ -45,7 +52,7 @@ def recreate(self): Create worker_pool on forked processes. """ - self._worker_pool = workerpool.WorkerPool(10, self.synchronize_segment) + self._worker_pool = workerpool.WorkerPool(_MAX_WORKERS, self.synchronize_segment) self._worker_pool.start() def shutdown(self): @@ -108,7 +115,7 @@ def _attempt_segment_sync(self, segment_name, fetch_options, till=None): :param segment_name: Name of the segment to update. :type segment_name: str - :param fetch_options Fetch options for getting split definitions. + :param fetch_options Fetch options for getting feature flag definitions. :type fetch_options splitio.api.FetchOptions :param till: Passed till from Streaming. @@ -124,8 +131,10 @@ def _attempt_segment_sync(self, segment_name, fetch_options, till=None): change_number = self._fetch_until(segment_name, fetch_options, till) if till is None or till <= change_number: return True, remaining_attempts, change_number + elif remaining_attempts <= 0: return False, remaining_attempts, change_number + how_long = self._backoff.get() time.sleep(how_long) @@ -139,32 +148,499 @@ def synchronize_segment(self, segment_name, till=None): :param till: ChangeNumber received. :type till: int + :return: True if no error occurs. False otherwise. + :rtype: bool """ - fetch_options = FetchOptions(True) # Set Cache-Control to no-cache + fetch_options = FetchOptions(True, spec=None) # Set Cache-Control to no-cache successful_sync, remaining_attempts, change_number = self._attempt_segment_sync(segment_name, fetch_options, till) attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts if successful_sync: # succedeed sync _LOGGER.debug('Refresh completed in %d attempts.', attempts) - return - with_cdn_bypass = FetchOptions(True, change_number) # Set flag for bypassing CDN + return True + with_cdn_bypass = FetchOptions(True, change_number, spec=None) # Set flag for bypassing CDN without_cdn_successful_sync, remaining_attempts, change_number = self._attempt_segment_sync(segment_name, with_cdn_bypass, till) without_cdn_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts if without_cdn_successful_sync: _LOGGER.debug('Refresh completed bypassing the CDN in %d attempts.', without_cdn_attempts) - return - else: - _LOGGER.debug('No changes fetched after %d attempts with CDN bypassed.', - without_cdn_attempts) + return True + + _LOGGER.debug('No changes fetched after %d attempts with CDN bypassed.', + without_cdn_attempts) + return False - def synchronize_segments(self): + def synchronize_segments(self, segment_names = None, dont_wait = False): """ - Submit all current segments and wait for them to finish, then set the ready flag. + Submit all current segments and wait for them to finish depend on dont_wait flag, then set the ready flag. - :return: True if no error occurs. False otherwise. + :param segment_names: Optional, array of segment names to update. + :type segment_name: {str} + + :param dont_wait: Optional, instruct the function to not wait for task completion + :type segment_name: boolean + + :return: True if no error occurs or dont_wait flag is True. False otherwise. :rtype: bool """ - segment_names = self._split_storage.get_segment_names() + if segment_names is None: + segment_names = set(self._feature_flag_storage.get_segment_names()) + segment_names.update(get_standard_segment_names_in_rbs_storage(self._rule_based_segment_storage)) + for segment_name in segment_names: + _LOGGER.debug("Adding segment name to sync worker") + _LOGGER.debug(segment_name) self._worker_pool.submit_work(segment_name) + if (dont_wait): + return True + return not self._worker_pool.wait_for_completion() + + def segment_exist_in_storage(self, segment_name): + """ + Check if a segment exists in the storage + + :param segment_name: Name of the segment + :type segment_name: str + + :return: True if segment exist. False otherwise. + :rtype: bool + """ + return self._segment_storage.get(segment_name) != None + + +class SegmentSynchronizerAsync(object): + def __init__(self, segment_api, feature_flag_storage, segment_storage, rule_based_segment_storage): + """ + Class constructor. + + :param segment_api: API to retrieve segments from backend. + :type segment_api: splitio.api.SegmentApi + + :param feature_flag_storage: Feature Flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + + :param segment_storage: Segment storage reference. + :type segment_storage: splitio.storage.SegmentStorage + + """ + self._api = segment_api + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + self._rule_based_segment_storage = rule_based_segment_storage + self._worker_pool = workerpool.WorkerPoolAsync(_MAX_WORKERS, self.synchronize_segment) + self._worker_pool.start() + self._backoff = Backoff( + _ON_DEMAND_FETCH_BACKOFF_BASE, + _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT) + + def recreate(self): + """ + Create worker_pool on forked processes. + + """ + self._worker_pool = workerpool.WorkerPoolAsync(_MAX_WORKERS, self.synchronize_segment) + self._worker_pool.start() + + async def shutdown(self): + """ + Shutdown worker_pool + + """ + await self._worker_pool.stop() + + async def _fetch_until(self, segment_name, fetch_options, till=None): + """ + Hit endpoint, update storage and return when since==till. + + :param segment_name: Name of the segment to update. + :type segment_name: str + + :param fetch_options Fetch options for getting segment definitions. + :type fetch_options splitio.api.FetchOptions + + :param till: Passed till from Streaming. + :type till: int + + :return: last change number + :rtype: int + """ + while True: # Fetch until since==till + change_number = await self._segment_storage.get_change_number(segment_name) + if change_number is None: + change_number = -1 + if till is not None and till < change_number: + # the passed till is less than change_number, no need to perform updates + return change_number + + try: + segment_changes = await self._api.fetch_segment(segment_name, change_number, + fetch_options) + except APIException as exc: + _LOGGER.error('Exception raised while fetching segment %s', segment_name) + _LOGGER.debug('Exception information: ', exc_info=True) + raise exc + + if change_number == -1: # first time fetching the segment + new_segment = segments.from_raw(segment_changes) + await self._segment_storage.put(new_segment) + else: + await self._segment_storage.update( + segment_name, + segment_changes['added'], + segment_changes['removed'], + segment_changes['till'] + ) + + if segment_changes['till'] == segment_changes['since']: + return segment_changes['till'] + + async def _attempt_segment_sync(self, segment_name, fetch_options, till=None): + """ + Hit endpoint, update storage and return True if sync is complete. + + :param segment_name: Name of the segment to update. + :type segment_name: str + + :param fetch_options Fetch options for getting feature flag definitions. + :type fetch_options splitio.api.FetchOptions + + :param till: Passed till from Streaming. + :type till: int + + :return: Flags to check if it should perform bypass or operation ended + :rtype: bool, int, int + """ + self._backoff.reset() + remaining_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES + while True: + remaining_attempts -= 1 + change_number = await self._fetch_until(segment_name, fetch_options, till) + if till is None or till <= change_number: + return True, remaining_attempts, change_number + + elif remaining_attempts <= 0: + return False, remaining_attempts, change_number + + how_long = self._backoff.get() + await asyncio.sleep(how_long) + + async def synchronize_segment(self, segment_name, till=None): + """ + Update a segment from queue + + :param segment_name: Name of the segment to update. + :type segment_name: str + + :param till: ChangeNumber received. + :type till: int + + :return: True if no error occurs. False otherwise. + :rtype: bool + """ + fetch_options = FetchOptions(True, spec=None) # Set Cache-Control to no-cache + successful_sync, remaining_attempts, change_number = await self._attempt_segment_sync(segment_name, fetch_options, till) + attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts + if successful_sync: # succedeed sync + _LOGGER.debug('Refresh completed in %d attempts.', attempts) + return True + + with_cdn_bypass = FetchOptions(True, change_number, spec=None) # Set flag for bypassing CDN + without_cdn_successful_sync, remaining_attempts, change_number = await self._attempt_segment_sync(segment_name, with_cdn_bypass, till) + without_cdn_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts + if without_cdn_successful_sync: + _LOGGER.debug('Refresh completed bypassing the CDN in %d attempts.', + without_cdn_attempts) + return True + + _LOGGER.debug('No changes fetched after %d attempts with CDN bypassed.', + without_cdn_attempts) + return False + + async def synchronize_segments(self, segment_names = None, dont_wait = False): + """ + Submit all current segments and wait for them to finish depend on dont_wait flag, then set the ready flag. + + :param segment_names: Optional, array of segment names to update. + :type segment_name: {str} + + :param dont_wait: Optional, instruct the function to not wait for task completion + :type segment_name: boolean + + :return: True if no error occurs or dont_wait flag is True. False otherwise. + :rtype: bool + """ + if segment_names is None: + segment_names = set(await self._feature_flag_storage.get_segment_names()) + segment_names.update(await get_standard_segment_names_in_rbs_storage_async(self._rule_based_segment_storage)) + + self._jobs = await self._worker_pool.submit_work(segment_names) + if (dont_wait): + return True + + return await self._jobs.await_completion() + + async def segment_exist_in_storage(self, segment_name): + """ + Check if a segment exists in the storage + + :param segment_name: Name of the segment + :type segment_name: str + + :return: True if segment exist. False otherwise. + :rtype: bool + """ + return await self._segment_storage.get(segment_name) != None + + +class LocalSegmentSynchronizerBase(object): + """Localhost mode segment base synchronizer.""" + + _DEFAULT_SEGMENT_TILL = -1 + + def _sanitize_segment(self, parsed): + """ + Sanitize json elements. + + :param parsed: segment dict + :type parsed: Dict + + :return: sanitized segment structure dict + :rtype: Dict + """ + if 'name' not in parsed or parsed['name'] is None: + _LOGGER.warning("Segment does not have [name] element, skipping") + raise Exception("Segment does not have [name] element") + if parsed['name'].strip() == '': + _LOGGER.warning("Segment [name] element is blank, skipping") + raise Exception("Segment [name] element is blank") + + for element in [('till', -1, -1, None, None, [0]), + ('added', [], None, None, None, None), + ('removed', [], None, None, None, None) + ]: + parsed = util._sanitize_object_element(parsed, 'segment', element[0], element[1], lower_value=element[2], upper_value=element[3], in_list=None, not_in_list=element[5]) + parsed = util._sanitize_object_element(parsed, 'segment', 'since', parsed['till'], -1, parsed['till'], None, [0]) + + return parsed + + +class LocalSegmentSynchronizer(LocalSegmentSynchronizerBase): + """Localhost mode segment synchronizer.""" + + def __init__(self, segment_folder, feature_flag_storage, segment_storage): + """ + Class constructor. + + :param segment_folder: patch to the segment folder + :type segment_folder: str + + :param feature_flag_storage: Feature flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + + :param segment_storage: Segment storage reference. + :type segment_storage: splitio.storage.SegmentStorage + + """ + self._segment_folder = segment_folder + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + self._segment_sha = {} + + def synchronize_segments(self, segment_names = None): + """ + Loop through given segment names and synchronize each one. + + :param segment_names: Optional, array of segment names to update. + :type segment_name: {str} + + :return: True if no error occurs. False otherwise. + :rtype: bool + """ + _LOGGER.info('Synchronizing segments now.') + if segment_names is None: + segment_names = self._feature_flag_storage.get_segment_names() + + return_flag = True + for segment_name in segment_names: + if not self.synchronize_segment(segment_name): + return_flag = False + + return return_flag + + def synchronize_segment(self, segment_name, till=None): + """ + Update a segment from queue + + :param segment_name: Name of the segment to update. + :type segment_name: str + + :param till: ChangeNumber received. + :type till: int + + :return: True if no error occurs. False otherwise. + :rtype: bool + """ + try: + fetched = self._read_segment_from_json_file(segment_name) + fetched_sha = util._get_sha(json.dumps(fetched)) + if not self.segment_exist_in_storage(segment_name): + self._segment_sha[segment_name] = fetched_sha + self._segment_storage.put(segments.from_raw(fetched)) + _LOGGER.debug("segment %s is added to storage", segment_name) + return True + + if fetched_sha == self._segment_sha[segment_name]: + return True + + self._segment_sha[segment_name] = fetched_sha + if self._segment_storage.get_change_number(segment_name) > fetched['till'] and fetched['till'] != self._DEFAULT_SEGMENT_TILL: + return True + + self._segment_storage.update(segment_name, fetched['added'], fetched['removed'], fetched['till']) + _LOGGER.debug("segment %s is updated", segment_name) + except Exception as e: + _LOGGER.error("Could not fetch segment: %s \n" + str(e), segment_name) + return False + + return True + + def _read_segment_from_json_file(self, filename): + """ + Parse a segment and store in segment storage. + + :param filename: Path of the file containing Feature flag + :type filename: str. + + :return: Sanitized segment structure + :rtype: Dict + """ + try: + with open(os.path.join(self._segment_folder, "%s.json" % filename), 'r') as flo: + parsed = json.load(flo) + santitized_segment = self._sanitize_segment(parsed) + return santitized_segment + except Exception as exc: + raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc + + def segment_exist_in_storage(self, segment_name): + """ + Check if a segment exists in the storage + + :param segment_name: Name of the segment + :type segment_name: str + + :return: True if segment exist. False otherwise. + :rtype: bool + """ + return self._segment_storage.get(segment_name) != None + + +class LocalSegmentSynchronizerAsync(LocalSegmentSynchronizerBase): + """Localhost mode segment async synchronizer.""" + + def __init__(self, segment_folder, feature_flag_storage, segment_storage): + """ + Class constructor. + + :param segment_folder: patch to the segment folder + :type segment_folder: str + + :param feature_flag_storage: Feature flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + + :param segment_storage: Segment storage reference. + :type segment_storage: splitio.storage.SegmentStorage + + """ + self._segment_folder = segment_folder + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + self._segment_sha = {} + + async def synchronize_segments(self, segment_names = None): + """ + Loop through given segment names and synchronize each one. + + :param segment_names: Optional, array of segment names to update. + :type segment_name: {str} + + :return: True if no error occurs. False otherwise. + :rtype: bool + """ + _LOGGER.info('Synchronizing segments now.') + if segment_names is None: + segment_names = await self._feature_flag_storage.get_segment_names() + + return_flag = True + for segment_name in segment_names: + if not await self.synchronize_segment(segment_name): + return_flag = False + + return return_flag + + async def synchronize_segment(self, segment_name, till=None): + """ + Update a segment from queue + + :param segment_name: Name of the segment to update. + :type segment_name: str + + :param till: ChangeNumber received. + :type till: int + + :return: True if no error occurs. False otherwise. + :rtype: bool + """ + try: + fetched = await self._read_segment_from_json_file(segment_name) + fetched_sha = util._get_sha(json.dumps(fetched)) + if not await self.segment_exist_in_storage(segment_name): + self._segment_sha[segment_name] = fetched_sha + await self._segment_storage.put(segments.from_raw(fetched)) + _LOGGER.debug("segment %s is added to storage", segment_name) + return True + + if fetched_sha == self._segment_sha[segment_name]: + return True + + self._segment_sha[segment_name] = fetched_sha + if await self._segment_storage.get_change_number(segment_name) > fetched['till'] and fetched['till'] != self._DEFAULT_SEGMENT_TILL: + return True + + await self._segment_storage.update(segment_name, fetched['added'], fetched['removed'], fetched['till']) + _LOGGER.debug("segment %s is updated", segment_name) + except Exception as e: + _LOGGER.error("Could not fetch segment: %s \n" + str(e), segment_name) + return False + + return True + + async def _read_segment_from_json_file(self, filename): + """ + Parse a segment and store in segment storage. + + :param filename: Path of the file containing Feature flag + :type filename: str. + + :return: Sanitized segment structure + :rtype: Dict + """ + try: + async with aiofiles.open(os.path.join(self._segment_folder, "%s.json" % filename), 'r') as flo: + parsed = json.loads(await flo.read()) + santitized_segment = self._sanitize_segment(parsed) + return santitized_segment + except Exception as exc: + raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc + + async def segment_exist_in_storage(self, segment_name): + """ + Check if a segment exists in the storage + + :param segment_name: Name of the segment + :type segment_name: str + + :return: True if segment exist. False otherwise. + :rtype: bool + """ + return await self._segment_storage.get(segment_name) != None diff --git a/splitio/sync/split.py b/splitio/sync/split.py index 5331f556..c1b5aa39 100644 --- a/splitio/sync/split.py +++ b/splitio/sync/split.py @@ -4,12 +4,20 @@ import itertools import yaml import time +import json +from enum import Enum -from splitio.api import APIException +from splitio.api import APIException, APIUriException from splitio.api.commons import FetchOptions -from splitio.models import splits +from splitio.client.input_validator import validate_flag_sets +from splitio.models import splits, rule_based_segments from splitio.util.backoff import Backoff - +from splitio.util.time import get_current_epoch_time_ms +from splitio.util.storage_helper import update_feature_flag_storage, update_feature_flag_storage_async, \ + update_rule_based_segment_storage, update_rule_based_segment_storage_async + +from splitio.sync import util +from splitio.optional.loaders import asyncio, aiofiles _LEGACY_COMMENT_LINE_RE = re.compile(r'^#.*$') _LEGACY_DEFINITION_LINE_RE = re.compile(r'^(?[\w_-]+)\s+(?P[\w_-]+)$') @@ -19,158 +27,410 @@ _ON_DEMAND_FETCH_BACKOFF_BASE = 10 # backoff base starting at 10 seconds -_ON_DEMAND_FETCH_BACKOFF_MAX_WAIT = 60 # don't sleep for more than 1 minute +_ON_DEMAND_FETCH_BACKOFF_MAX_WAIT = 30 # don't sleep for more than 30 seconds _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES = 10 -class SplitSynchronizer(object): - """Split changes synchronizer.""" +class SplitSynchronizerBase(object): + """Feature Flag changes synchronizer.""" - def __init__(self, split_api, split_storage): + def __init__(self, feature_flag_api, feature_flag_storage, rule_based_segment_storage): """ Class constructor. - :param split_api: Split API Client. - :type split_api: splitio.api.splits.SplitsAPI + :param feature_flag_api: Feature Flag API Client. + :type feature_flag_api: splitio.api.splits.SplitsAPI + + :param feature_flag_storage: Feature Flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage - :param split_storage: Split Storage. - :type split_storage: splitio.storage.InMemorySplitStorage + :param rule_based_segment_storage: Rule based segment Storage. + :type rule_based_segment_storage: splitio.storage.InMemoryRuleBasedStorage """ - self._api = split_api - self._split_storage = split_storage + self._api = feature_flag_api + self._feature_flag_storage = feature_flag_storage + self._rule_based_segment_storage = rule_based_segment_storage self._backoff = Backoff( _ON_DEMAND_FETCH_BACKOFF_BASE, _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT) - def _fetch_until(self, fetch_options, till=None): + @property + def feature_flag_storage(self): + """Return Feature_flag storage object""" + return self._feature_flag_storage + + @property + def rule_based_segment_storage(self): + """Return rule base segment storage object""" + return self._rule_based_segment_storage + + def _get_config_sets(self): + """ + Get all filter flag sets cnverrted to string, if no filter flagsets exist return None + :return: string with flagsets + :rtype: str + """ + if self._feature_flag_storage.flag_set_filter.flag_sets == set({}): + return None + + return ','.join(self._feature_flag_storage.flag_set_filter.sorted_flag_sets) + + def _check_exit_conditions(self, till, rbs_till, change_number, rbs_change_number): + return (till is not None and till < change_number) or (rbs_till is not None and rbs_till < rbs_change_number) + + def _check_return_conditions(self, feature_flag_changes): + return feature_flag_changes.get('ff')['t'] == feature_flag_changes.get('ff')['s'] and feature_flag_changes.get('rbs')['t'] == feature_flag_changes.get('rbs')['s'] + +class SplitSynchronizer(SplitSynchronizerBase): + """Feature Flag changes synchronizer.""" + + def __init__(self, feature_flag_api, feature_flag_storage, rule_based_segment_storage): + """ + Class constructor. + + :param feature_flag_api: Feature Flag API Client. + :type feature_flag_api: splitio.api.splits.SplitsAPI + + :param feature_flag_storage: Feature Flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + + :param rule_based_segment_storage: Rule based segment Storage. + :type rule_based_segment_storage: splitio.storage.InMemoryRuleBasedStorage + """ + SplitSynchronizerBase.__init__(self, feature_flag_api, feature_flag_storage, rule_based_segment_storage) + + def _fetch_until(self, fetch_options, till=None, rbs_till=None): """ Hit endpoint, update storage and return when since==till. - :param fetch_options Fetch options for getting split definitions. + :param fetch_options Fetch options for getting feature flag definitions. :type fetch_options splitio.api.FetchOptions :param till: Passed till from Streaming. :type till: int + :param rbs_till: Passed rbs till from Streaming. + :type rbs_till: int + :return: last change number :rtype: int """ + segment_list = set() while True: # Fetch until since==till - change_number = self._split_storage.get_change_number() + change_number = self._feature_flag_storage.get_change_number() if change_number is None: change_number = -1 - if till is not None and till < change_number: + + rbs_change_number = self._rule_based_segment_storage.get_change_number() + if rbs_change_number is None: + rbs_change_number = -1 + + if self._check_exit_conditions(till, rbs_till, change_number, rbs_change_number): # the passed till is less than change_number, no need to perform updates - return change_number + return change_number, rbs_change_number, segment_list try: - split_changes = self._api.fetch_splits(change_number, fetch_options) + feature_flag_changes = self._api.fetch_splits(change_number, rbs_change_number, fetch_options) except APIException as exc: - _LOGGER.error('Exception raised while fetching splits') + if exc._status_code is not None and exc._status_code == 414: + _LOGGER.error('Exception caught: the amount of flag sets provided are big causing uri length error.') + _LOGGER.debug('Exception information: ', exc_info=True) + raise APIUriException("URI is too long due to FlagSets count", exc._status_code) + + _LOGGER.error('Exception raised while fetching feature flags') _LOGGER.debug('Exception information: ', exc_info=True) raise exc + + fetched_rule_based_segments = [(rule_based_segments.from_raw(rule_based_segment)) for rule_based_segment in feature_flag_changes.get('rbs').get('d', [])] + rbs_segment_list = update_rule_based_segment_storage(self._rule_based_segment_storage, fetched_rule_based_segments, feature_flag_changes.get('rbs')['t'], self._api.clear_storage) + + fetched_feature_flags = [(splits.from_raw(feature_flag)) for feature_flag in feature_flag_changes.get('ff').get('d', [])] + segment_list.update(update_feature_flag_storage(self._feature_flag_storage, fetched_feature_flags, feature_flag_changes.get('ff')['t'], self._api.clear_storage)) + segment_list.update(rbs_segment_list) + + if self._check_return_conditions(feature_flag_changes): + return feature_flag_changes.get('ff')['t'], feature_flag_changes.get('rbs')['t'], segment_list + + def _attempt_feature_flag_sync(self, fetch_options, till=None, rbs_till=None): + """ + Hit endpoint, update storage and return True if sync is complete. - for split in split_changes.get('splits', []): - if split['status'] == splits.Status.ACTIVE.value: - self._split_storage.put(splits.from_raw(split)) - else: - self._split_storage.remove(split['name']) + :param fetch_options Fetch options for getting feature flag definitions. + :type fetch_options splitio.api.FetchOptions + + :param till: Passed till from Streaming. + :type till: int + + :param rbs_till: Passed rbs till from Streaming. + :type rbs_till: int + + :return: Flags to check if it should perform bypass or operation ended + :rtype: bool, int, int + """ + self._backoff.reset() + final_segment_list = set() + remaining_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES + while True: + remaining_attempts -= 1 + change_number, rbs_change_number, segment_list = self._fetch_until(fetch_options, till, rbs_till) + final_segment_list.update(segment_list) + if (till is None or till <= change_number) and (rbs_till is None or rbs_till <= rbs_change_number): + return True, remaining_attempts, change_number, rbs_change_number, final_segment_list + + elif remaining_attempts <= 0: + return False, remaining_attempts, change_number, rbs_change_number, final_segment_list + + how_long = self._backoff.get() + time.sleep(how_long) + + def _get_config_sets(self): + """ + Get all filter flag sets cnverrted to string, if no filter flagsets exist return None + + :return: string with flagsets + :rtype: str + """ + if self._feature_flag_storage.flag_set_filter.flag_sets == set({}): + return None + + return ','.join(self._feature_flag_storage.flag_set_filter.sorted_flag_sets) + + def synchronize_splits(self, till=None, rbs_till=None): + """ + Hit endpoint, update storage and return True if sync is complete. + + :param till: Passed till from Streaming. + :type till: int + + :param rbs_till: Passed rbs till from Streaming. + :type rbs_till: int + """ + final_segment_list = set() + fetch_options = FetchOptions(True, sets=self._get_config_sets()) # Set Cache-Control to no-cache + successful_sync, remaining_attempts, change_number, rbs_change_number, segment_list = self._attempt_feature_flag_sync(fetch_options, + till, rbs_till) + final_segment_list.update(segment_list) + attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts + if successful_sync: # succedeed sync + _LOGGER.debug('Refresh completed in %d attempts.', attempts) + return final_segment_list + + with_cdn_bypass = FetchOptions(True, change_number, rbs_change_number, sets=self._get_config_sets()) # Set flag for bypassing CDN + without_cdn_successful_sync, remaining_attempts, change_number, rbs_change_number, segment_list = self._attempt_feature_flag_sync(with_cdn_bypass, till, rbs_till) + final_segment_list.update(segment_list) + without_cdn_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts + if without_cdn_successful_sync: + _LOGGER.debug('Refresh completed bypassing the CDN in %d attempts.', + without_cdn_attempts) + return final_segment_list + else: + _LOGGER.debug('No changes fetched after %d attempts with CDN bypassed.', + without_cdn_attempts) + + def kill_split(self, feature_flag_name, default_treatment, change_number): + """ + Local kill for feature flag. + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + self._feature_flag_storage.kill_locally(feature_flag_name, default_treatment, change_number) - self._split_storage.set_change_number(split_changes['till']) - if split_changes['till'] == split_changes['since']: - return split_changes['till'] +class SplitSynchronizerAsync(SplitSynchronizerBase): + """Feature Flag changes synchronizer async.""" - def _attempt_split_sync(self, fetch_options, till=None): + def __init__(self, feature_flag_api, feature_flag_storage, rule_based_segment_storage): + """ + Class constructor. + + :param feature_flag_api: Feature Flag API Client. + :type feature_flag_api: splitio.api.splits.SplitsAPI + + :param feature_flag_storage: Feature Flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + + :param rule_based_segment_storage: Rule based segment Storage. + :type rule_based_segment_storage: splitio.storage.InMemoryRuleBasedStorage + """ + SplitSynchronizerBase.__init__(self, feature_flag_api, feature_flag_storage, rule_based_segment_storage) + + async def _fetch_until(self, fetch_options, till=None, rbs_till=None): + """ + Hit endpoint, update storage and return when since==till. + + :param fetch_options Fetch options for getting feature flag definitions. + :type fetch_options splitio.api.FetchOptions + + :param till: Passed till from Streaming. + :type till: int + + :param rbs_till: Passed rbs till from Streaming. + :type rbs_till: int + + :return: last change number + :rtype: int + """ + segment_list = set() + while True: # Fetch until since==till + change_number = await self._feature_flag_storage.get_change_number() + if change_number is None: + change_number = -1 + + rbs_change_number = await self._rule_based_segment_storage.get_change_number() + if rbs_change_number is None: + rbs_change_number = -1 + + if self._check_exit_conditions(till, rbs_till, change_number, rbs_change_number): + # the passed till is less than change_number, no need to perform updates + return change_number, rbs_change_number, segment_list + + try: + feature_flag_changes = await self._api.fetch_splits(change_number, rbs_change_number, fetch_options) + except APIException as exc: + if exc._status_code is not None and exc._status_code == 414: + _LOGGER.error('Exception caught: the amount of flag sets provided are big causing uri length error.') + _LOGGER.debug('Exception information: ', exc_info=True) + raise APIUriException("URI is too long due to FlagSets count", exc._status_code) + + _LOGGER.error('Exception raised while fetching feature flags') + _LOGGER.debug('Exception information: ', exc_info=True) + raise exc + + fetched_rule_based_segments = [(rule_based_segments.from_raw(rule_based_segment)) for rule_based_segment in feature_flag_changes.get('rbs').get('d', [])] + rbs_segment_list = await update_rule_based_segment_storage_async(self._rule_based_segment_storage, fetched_rule_based_segments, feature_flag_changes.get('rbs')['t'], self._api.clear_storage) + + fetched_feature_flags = [(splits.from_raw(feature_flag)) for feature_flag in feature_flag_changes.get('ff').get('d', [])] + segment_list = await update_feature_flag_storage_async(self._feature_flag_storage, fetched_feature_flags, feature_flag_changes.get('ff')['t'], self._api.clear_storage) + segment_list.update(rbs_segment_list) + + if self._check_return_conditions(feature_flag_changes): + return feature_flag_changes.get('ff')['t'], feature_flag_changes.get('rbs')['t'], segment_list + + async def _attempt_feature_flag_sync(self, fetch_options, till=None, rbs_till=None): """ Hit endpoint, update storage and return True if sync is complete. - :param fetch_options Fetch options for getting split definitions. + :param fetch_options Fetch options for getting feature flag definitions. :type fetch_options splitio.api.FetchOptions :param till: Passed till from Streaming. :type till: int + :param rbs_till: Passed rbs till from Streaming. + :type rbs_till: int + :return: Flags to check if it should perform bypass or operation ended :rtype: bool, int, int """ self._backoff.reset() + final_segment_list = set() remaining_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES while True: remaining_attempts -= 1 - change_number = self._fetch_until(fetch_options, till) - if till is None or till <= change_number: - return True, remaining_attempts, change_number + change_number, rbs_change_number, segment_list = await self._fetch_until(fetch_options, till, rbs_till) + final_segment_list.update(segment_list) + if (till is None or till <= change_number) and (rbs_till is None or rbs_till <= rbs_change_number): + return True, remaining_attempts, change_number, rbs_change_number, final_segment_list + elif remaining_attempts <= 0: - return False, remaining_attempts, change_number + return False, remaining_attempts, change_number, rbs_change_number, final_segment_list + how_long = self._backoff.get() - time.sleep(how_long) + await asyncio.sleep(how_long) - def synchronize_splits(self, till=None): + async def synchronize_splits(self, till=None, rbs_till=None): """ Hit endpoint, update storage and return True if sync is complete. :param till: Passed till from Streaming. :type till: int + + :param rbs_till: Passed rbs till from Streaming. + :type rbs_till: int """ - fetch_options = FetchOptions(True) # Set Cache-Control to no-cache - successful_sync, remaining_attempts, change_number = self._attempt_split_sync(fetch_options, - till) + final_segment_list = set() + fetch_options = FetchOptions(True, sets=self._get_config_sets()) # Set Cache-Control to no-cache + successful_sync, remaining_attempts, change_number, rbs_change_number, segment_list = await self._attempt_feature_flag_sync(fetch_options, + till, rbs_till) + final_segment_list.update(segment_list) attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts if successful_sync: # succedeed sync _LOGGER.debug('Refresh completed in %d attempts.', attempts) - return - with_cdn_bypass = FetchOptions(True, change_number) # Set flag for bypassing CDN - without_cdn_successful_sync, remaining_attempts, change_number = self._attempt_split_sync(with_cdn_bypass, till) + return final_segment_list + + with_cdn_bypass = FetchOptions(True, change_number, rbs_change_number, sets=self._get_config_sets()) # Set flag for bypassing CDN + without_cdn_successful_sync, remaining_attempts, change_number, rbs_change_number, segment_list = await self._attempt_feature_flag_sync(with_cdn_bypass, till, rbs_till) + final_segment_list.update(segment_list) without_cdn_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts if without_cdn_successful_sync: _LOGGER.debug('Refresh completed bypassing the CDN in %d attempts.', without_cdn_attempts) - return + return final_segment_list + else: _LOGGER.debug('No changes fetched after %d attempts with CDN bypassed.', without_cdn_attempts) - def kill_split(self, split_name, default_treatment, change_number): + async def kill_split(self, feature_flag_name, default_treatment, change_number): """ - Local kill for split. + Local kill for feature flag. - :param split_name: name of the split to perform kill - :type split_name: str + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str :param default_treatment: name of the default treatment to return :type default_treatment: str :param change_number: change_number :type change_number: int """ - self._split_storage.kill_locally(split_name, default_treatment, change_number) + await self._feature_flag_storage.kill_locally(feature_flag_name, default_treatment, change_number) + + +class LocalhostMode(Enum): + """types for localhost modes""" + LEGACY = 0 + YAML = 1 + JSON = 2 +class LocalSplitSynchronizerBase(object): + """Localhost mode feature_flag base synchronizer.""" -class LocalSplitSynchronizer(object): - """Localhost mode split synchronizer.""" + _DEFAULT_FEATURE_FLAG_TILL = -1 + _DEFAULT_RB_SEGMENT_TILL = -1 - def __init__(self, filename, split_storage): + def __init__(self, filename, feature_flag_storage, rule_based_segment_storage, localhost_mode=LocalhostMode.LEGACY): """ Class constructor. - :param filename: File to parse splits from. + :param filename: File to parse feature flags from. :type filename: str - :param split_storage: Split Storage. - :type split_storage: splitio.storage.InMemorySplitStorage + :param feature_flag_storage: Feature flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + :param localhost_mode: mode for localhost either JSON, YAML or LEGACY. + :type localhost_mode: splitio.sync.split.LocalhostMode """ self._filename = filename - self._split_storage = split_storage + self._feature_flag_storage = feature_flag_storage + self._rule_based_segment_storage = rule_based_segment_storage + self._localhost_mode = localhost_mode + self._current_ff_sha = "-1" + self._current_rbs_sha = "-1" @staticmethod - def _make_split(split_name, conditions, configs=None): + def _make_feature_flag(feature_flag_name, conditions, configs=None): """ - Make a split with a single all_keys matcher. + Make a Feature flag with a single all_keys matcher. - :param split_name: Name of the split. - :type split_name: str. + :param feature_flag_name: Name of the feature flag. + :type feature_flag_name: str. """ return splits.from_raw({ 'changeNumber': 123, 'trafficTypeName': 'user', - 'name': split_name, + 'name': feature_flag_name, 'trafficAllocation': 100, 'trafficAllocationSeed': 123456, 'seed': 321654, @@ -179,7 +439,8 @@ def _make_split(split_name, conditions, configs=None): 'defaultTreatment': 'control', 'algo': 2, 'conditions': conditions, - 'configurations': configs + 'configurations': configs, + 'prerequisites': [] }) @staticmethod @@ -222,16 +483,214 @@ def _make_whitelist_condition(whitelist, treatment): 'combiner': 'AND' } } + + def _sanitize_json_elements(self, parsed): + """ + Sanitize all json elements. + + :param parsed: feature flags, till and since elements dict + :type parsed: Dict + :return: sanitized structure dict + :rtype: Dict + """ + parsed = self._satitize_json_section(parsed, 'ff') + parsed = self._satitize_json_section(parsed, 'rbs') + + return parsed + + def _satitize_json_section(self, parsed, section_name): + """ + Sanitize specific json section. + + :param parsed: feature flags, till and since elements dict + :type parsed: Dict + + :return: sanitized structure dict + :rtype: Dict + """ + if section_name not in parsed: + parsed['ff'] = {"t": -1, "s": -1, "d": []} + if 'd' not in parsed[section_name]: + parsed[section_name]['d'] = [] + if 't' not in parsed[section_name] or parsed[section_name]['t'] is None or parsed[section_name]['t'] < -1: + parsed[section_name]['t'] = -1 + if 's' not in parsed[section_name] or parsed[section_name]['s'] is None or parsed[section_name]['s'] < -1 or parsed[section_name]['s'] > parsed[section_name]['t']: + parsed[section_name]['s'] = parsed[section_name]['t'] + + return parsed + + def _sanitize_feature_flag_elements(self, parsed_feature_flags): + """ + Sanitize all feature flags elements. + + :param parsed_feature_flags: feature flags array + :type parsed_feature_flags: [Dict] + + :return: sanitized structure dict + :rtype: [Dict] + """ + sanitized_feature_flags = [] + for feature_flag in parsed_feature_flags: + if 'name' not in feature_flag or feature_flag['name'].strip() == '': + _LOGGER.warning("A feature flag in json file does not have (Name) or property is empty, skipping.") + continue + for element in [('trafficTypeName', 'user', None, None, None, None), + ('trafficAllocation', 100, 0, 100, None, None), + ('trafficAllocationSeed', int(get_current_epoch_time_ms() / 1000), None, None, None, [0]), + ('seed', int(get_current_epoch_time_ms() / 1000), None, None, None, [0]), + ('status', splits.Status.ACTIVE.value, None, None, [e.value for e in splits.Status], None), + ('killed', False, None, None, None, None), + ('defaultTreatment', 'control', None, None, None, ['', ' ']), + ('changeNumber', 0, 0, None, None, None), + ('algo', 2, 2, 2, None, None)]: + feature_flag = util._sanitize_object_element(feature_flag, 'split', element[0], element[1], lower_value=element[2], upper_value=element[3], in_list=element[4], not_in_list=element[5]) + feature_flag = self._sanitize_condition(feature_flag) + if 'sets' not in feature_flag: + feature_flag['sets'] = [] + feature_flag['sets'] = validate_flag_sets(feature_flag['sets'], 'Localhost Validator') + if 'prerequisites' not in feature_flag: + feature_flag['prerequisites'] = [] + sanitized_feature_flags.append(feature_flag) + return sanitized_feature_flags + + def _sanitize_rb_segment_elements(self, parsed_rb_segments): + """ + Sanitize all rule based segments elements. + + :param parsed_rb_segments: rule based segments array + :type parsed_rb_segments: [Dict] + + :return: sanitized structure dict + :rtype: [Dict] + """ + sanitized_rb_segments = [] + for rb_segment in parsed_rb_segments: + if 'name' not in rb_segment or rb_segment['name'].strip() == '': + _LOGGER.warning("A rule based segment in json file does not have (Name) or property is empty, skipping.") + continue + + for element in [('trafficTypeName', 'user', None, None, None, None), + ('status', splits.Status.ACTIVE.value, None, None, [e.value for e in splits.Status], None), + ('changeNumber', 0, 0, None, None, None)]: + rb_segment = util._sanitize_object_element(rb_segment, 'rule based segment', element[0], element[1], lower_value=element[2], upper_value=element[3], in_list=element[4], not_in_list=element[5]) + rb_segment = self._sanitize_condition(rb_segment) + rb_segment = self._remove_partition(rb_segment) + sanitized_rb_segments.append(rb_segment) + return sanitized_rb_segments + + def _sanitize_condition(self, feature_flag): + """ + Sanitize feature flag and ensure a condition type ROLLOUT and matcher exist with ALL_KEYS elements. + + :param feature_flag: feature flag dict object + :type feature_flag: Dict + + :return: sanitized feature flag + :rtype: Dict + """ + found_all_keys_matcher = False + feature_flag['conditions'] = feature_flag.get('conditions', []) + if len(feature_flag['conditions']) > 0: + last_condition = feature_flag['conditions'][-1] + if 'conditionType' in last_condition: + if last_condition['conditionType'] == 'ROLLOUT': + if 'matcherGroup' in last_condition: + if 'matchers' in last_condition['matcherGroup']: + for matcher in last_condition['matcherGroup']['matchers']: + if matcher['matcherType'] == 'ALL_KEYS': + found_all_keys_matcher = True + break + + if not found_all_keys_matcher: + _LOGGER.debug("Missing default rule condition for feature flag: %s, adding default rule with 100%% off treatment", feature_flag['name']) + feature_flag['conditions'].append( + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [{ + "keySelector": { "trafficType": "user", "attribute": None }, + "matcherType": "ALL_KEYS", + "negate": False, + "userDefinedSegmentMatcherData": None, + "whitelistMatcherData": None, + "unaryNumericMatcherData": None, + "betweenMatcherData": None, + "booleanMatcherData": None, + "dependencyMatcherData": None, + "stringMatcherData": None + }] + }, + "partitions": [ + { "treatment": "on", "size": 0 }, + { "treatment": "off", "size": 100 } + ], + "label": "default rule" + }) + + return feature_flag + + def _remove_partition(self, rb_segment): + sanitized = [] + for condition in rb_segment['conditions']: + if 'partition' in condition: + del condition['partition'] + sanitized.append(condition) + rb_segment['conditions'] = sanitized + return rb_segment + + @classmethod + def _convert_yaml_to_feature_flag(cls, parsed): + grouped_by_feature_name = itertools.groupby( + sorted(parsed, key=lambda i: next(iter(i.keys()))), + lambda i: next(iter(i.keys()))) + to_return = {} + for (feature_flag_name, statements) in grouped_by_feature_name: + configs = {} + whitelist = [] + all_keys = [] + for statement in statements: + data = next(iter(statement.values())) # grab the first (and only) value. + if 'keys' in data: + keys = data['keys'] if isinstance(data['keys'], list) else [data['keys']] + whitelist.append(cls._make_whitelist_condition(keys, data['treatment'])) + else: + all_keys.append(cls._make_all_keys_condition(data['treatment'])) + if 'config' in data: + configs[data['treatment']] = data['config'] + to_return[feature_flag_name] = cls._make_feature_flag(feature_flag_name, whitelist + all_keys, configs) + return to_return + + def _check_exit_conditions(self, storage_cn, parsed_till, default_till): + if storage_cn > parsed_till and parsed_till != default_till: + return True + +class LocalSplitSynchronizer(LocalSplitSynchronizerBase): + """Localhost mode feature_flag synchronizer.""" + + def __init__(self, filename, feature_flag_storage, rule_based_segment_storage, localhost_mode=LocalhostMode.LEGACY): + """ + Class constructor. + + :param filename: File to parse feature flags from. + :type filename: str + :param feature_flag_storage: Feature flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + :param localhost_mode: mode for localhost either JSON, YAML or LEGACY. + :type localhost_mode: splitio.sync.split.LocalhostMode + """ + LocalSplitSynchronizerBase.__init__(self, filename, feature_flag_storage, rule_based_segment_storage, localhost_mode) + @classmethod - def _read_splits_from_legacy_file(cls, filename): + def _read_feature_flags_from_legacy_file(cls, filename): """ - Parse a splits file and return a populated storage. + Parse a feature flags file and return a populated storage. - :param filename: Path of the file containing mocked splits & treatments. + :param filename: Path of the file containing mocked feature flags & treatments. :type filename: str. - :return: Storage populataed with splits ready to be evaluated. + :return: Storage populataed with feature flags ready to be evaluated. :rtype: InMemorySplitStorage """ to_return = {} @@ -244,14 +703,14 @@ def _read_splits_from_legacy_file(cls, filename): definition_match = _LEGACY_DEFINITION_LINE_RE.match(line) if not definition_match: _LOGGER.warning( - 'Invalid line on localhost environment split ' + 'Invalid line on localhost environment feature flag ' 'definition. Line = %s', line ) continue cond = cls._make_all_keys_condition(definition_match.group('treatment')) - splt = cls._make_split(definition_match.group('feature'), [cond]) + splt = cls._make_feature_flag(definition_match.group('feature'), [cond]) to_return[splt.name] = splt return to_return @@ -259,55 +718,272 @@ def _read_splits_from_legacy_file(cls, filename): raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc @classmethod - def _read_splits_from_yaml_file(cls, filename): + def _read_feature_flags_from_yaml_file(cls, filename): """ - Parse a splits file and return a populated storage. + Parse a feature flags file and return a populated storage. - :param filename: Path of the file containing mocked splits & treatments. + :param filename: Path of the file containing mocked feature flags & treatments. :type filename: str. - :return: Storage populataed with splits ready to be evaluated. + :return: Storage populated with feature flags ready to be evaluated. :rtype: InMemorySplitStorage """ try: with open(filename, 'r') as flo: parsed = yaml.load(flo.read(), Loader=yaml.FullLoader) - grouped_by_feature_name = itertools.groupby( - sorted(parsed, key=lambda i: next(iter(i.keys()))), - lambda i: next(iter(i.keys()))) - - to_return = {} - for (split_name, statements) in grouped_by_feature_name: - configs = {} - whitelist = [] - all_keys = [] - for statement in statements: - data = next(iter(statement.values())) # grab the first (and only) value. - if 'keys' in data: - keys = data['keys'] if isinstance(data['keys'], list) else [data['keys']] - whitelist.append(cls._make_whitelist_condition(keys, data['treatment'])) - else: - all_keys.append(cls._make_all_keys_condition(data['treatment'])) - if 'config' in data: - configs[data['treatment']] = data['config'] - to_return[split_name] = cls._make_split(split_name, whitelist + all_keys, configs) + return cls._convert_yaml_to_feature_flag(parsed) + except IOError as exc: + raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc + + def synchronize_splits(self, till=None): # pylint:disable=unused-argument + """Update feature flags in storage.""" + _LOGGER.info('Synchronizing feature flags now.') + try: + return self._synchronize_json() if self._localhost_mode == LocalhostMode.JSON else self._synchronize_legacy() + except Exception as exc: + _LOGGER.debug('Exception: ', exc_info=True) + raise APIException("Error fetching feature flags information") from exc + + def _synchronize_legacy(self): + """ + Update feature flags in storage for legacy mode. + + :return: empty array for compatibility with json mode + :rtype: [] + """ + + if self._filename.lower().endswith(('.yaml', '.yml')): + fetched = self._read_feature_flags_from_yaml_file(self._filename) + else: + fetched = self._read_feature_flags_from_legacy_file(self._filename) + to_delete = [name for name in self._feature_flag_storage.get_split_names() + if name not in fetched.keys()] + to_add = [feature_flag for feature_flag in fetched.values()] + self._feature_flag_storage.update(to_add, to_delete, 0) + return [] + + def _synchronize_json(self): + """ + Update feature flags in storage for json mode. + + :return: segment names string array + :rtype: [str] + """ + try: + parsed = self._read_feature_flags_from_json_file(self._filename) + segment_list = set() + fecthed_ff_sha = util._get_sha(json.dumps(parsed['ff'])) + fecthed_rbs_sha = util._get_sha(json.dumps(parsed['rbs'])) + + if fecthed_ff_sha == self._current_ff_sha and fecthed_rbs_sha == self._current_rbs_sha: + return [] + + self._current_ff_sha = fecthed_ff_sha + self._current_rbs_sha = fecthed_rbs_sha + + if self._check_exit_conditions(self._feature_flag_storage.get_change_number(), parsed['ff']['t'], self._DEFAULT_FEATURE_FLAG_TILL) \ + and self._check_exit_conditions(self._rule_based_segment_storage.get_change_number(), parsed['rbs']['t'], self._DEFAULT_RB_SEGMENT_TILL): + return [] + + if not self._check_exit_conditions(self._feature_flag_storage.get_change_number(), parsed['ff']['t'], self._DEFAULT_FEATURE_FLAG_TILL): + fetched_feature_flags = [splits.from_raw(feature_flag) for feature_flag in parsed['ff']['d']] + segment_list = update_feature_flag_storage(self._feature_flag_storage, fetched_feature_flags, parsed['ff']['t']) + + if not self._check_exit_conditions(self._rule_based_segment_storage.get_change_number(), parsed['rbs']['t'], self._DEFAULT_RB_SEGMENT_TILL): + fetched_rb_segments = [rule_based_segments.from_raw(rb_segment) for rb_segment in parsed['rbs']['d']] + segment_list.update(update_rule_based_segment_storage(self._rule_based_segment_storage, fetched_rb_segments, parsed['rbs']['t'])) + + return segment_list + + except Exception as exc: + _LOGGER.debug('Exception: ', exc_info=True) + raise ValueError("Error reading feature flags from json.") from exc + + def _read_feature_flags_from_json_file(self, filename): + """ + Parse a feature flags file and return a populated storage. + + :param filename: Path of the file containing feature flags + :type filename: str. + + :return: Tuple: sanitized feature flag structure dict and till + :rtype: Tuple(Dict, int) + """ + try: + with open(filename, 'r') as flo: + parsed = json.load(flo) + + # check if spec version is old + if parsed.get('splits'): + parsed = util.convert_to_new_spec(parsed) + + santitized = self._sanitize_json_elements(parsed) + santitized['ff']['d'] = self._sanitize_feature_flag_elements(santitized['ff']['d']) + santitized['rbs']['d'] = self._sanitize_rb_segment_elements(santitized['rbs']['d']) + return santitized + + except Exception as exc: + _LOGGER.debug('Exception: ', exc_info=True) + raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc + +class LocalSplitSynchronizerAsync(LocalSplitSynchronizerBase): + """Localhost mode async feature_flag synchronizer.""" + + def __init__(self, filename, feature_flag_storage, rule_based_segment_storage, localhost_mode=LocalhostMode.LEGACY): + """ + Class constructor. + + :param filename: File to parse feature flags from. + :type filename: str + :param feature_flag_storage: Feature flag Storage. + :type feature_flag_storage: splitio.storage.InMemorySplitStorage + :param localhost_mode: mode for localhost either JSON, YAML or LEGACY. + :type localhost_mode: splitio.sync.split.LocalhostMode + """ + LocalSplitSynchronizerBase.__init__(self, filename, feature_flag_storage, rule_based_segment_storage, localhost_mode) + + @classmethod + async def _read_feature_flags_from_legacy_file(cls, filename): + """ + Parse a feature flags file and return a populated storage. + + :param filename: Path of the file containing mocked feature flags & treatments. + :type filename: str. + + :return: Storage populataed with feature flags ready to be evaluated. + :rtype: InMemorySplitStorage + """ + to_return = {} + try: + async with aiofiles.open(filename, 'r') as flo: + for line in await flo.read(): + if line.strip() == '' or _LEGACY_COMMENT_LINE_RE.match(line): + continue + + definition_match = _LEGACY_DEFINITION_LINE_RE.match(line) + if not definition_match: + _LOGGER.warning( + 'Invalid line on localhost environment feature flag ' + 'definition. Line = %s', + line + ) + continue + + cond = cls._make_all_keys_condition(definition_match.group('treatment')) + splt = cls._make_feature_flag(definition_match.group('feature'), [cond]) + to_return[splt.name] = splt return to_return except IOError as exc: raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc - def synchronize_splits(self, till=None): # pylint:disable=unused-argument - """Update splits in storage.""" - _LOGGER.info('Synchronizing splits now.') + @classmethod + async def _read_feature_flags_from_yaml_file(cls, filename): + """ + Parse a feature flags file and return a populated storage. + + :param filename: Path of the file containing mocked feature flags & treatments. + :type filename: str. + + :return: Storage populated with feature flags ready to be evaluated. + :rtype: InMemorySplitStorage + """ + try: + async with aiofiles.open(filename, 'r') as flo: + parsed = yaml.load(await flo.read(), Loader=yaml.FullLoader) + + return cls._convert_yaml_to_feature_flag(parsed) + except IOError as exc: + raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc + + async def synchronize_splits(self, till=None): # pylint:disable=unused-argument + """Update feature flags in storage.""" + _LOGGER.info('Synchronizing feature flags now.') + try: + return await self._synchronize_json() if self._localhost_mode == LocalhostMode.JSON else await self._synchronize_legacy() + except Exception as exc: + _LOGGER.debug('Exception: ', exc_info=True) + raise APIException("Error fetching feature flags information") from exc + + async def _synchronize_legacy(self): + """ + Update feature flags in storage for legacy mode. + + :return: empty array for compatibility with json mode + :rtype: [] + """ + if self._filename.lower().endswith(('.yaml', '.yml')): - fetched = self._read_splits_from_yaml_file(self._filename) + fetched = await self._read_feature_flags_from_yaml_file(self._filename) else: - fetched = self._read_splits_from_legacy_file(self._filename) - to_delete = [name for name in self._split_storage.get_split_names() + fetched = await self._read_feature_flags_from_legacy_file(self._filename) + to_delete = [name for name in await self._feature_flag_storage.get_split_names() if name not in fetched.keys()] - for split in fetched.values(): - self._split_storage.put(split) + to_add = [feature_flag for feature_flag in fetched.values()] + await self._feature_flag_storage.update(to_add, to_delete, 0) + + return [] + + async def _synchronize_json(self): + """ + Update feature flags in storage for json mode. + + :return: segment names string array + :rtype: [str] + """ + try: + parsed = await self._read_feature_flags_from_json_file(self._filename) + segment_list = set() + fecthed_ff_sha = util._get_sha(json.dumps(parsed['ff'])) + fecthed_rbs_sha = util._get_sha(json.dumps(parsed['rbs'])) + + if fecthed_ff_sha == self._current_ff_sha and fecthed_rbs_sha == self._current_rbs_sha: + return [] + + self._current_ff_sha = fecthed_ff_sha + self._current_rbs_sha = fecthed_rbs_sha + + if self._check_exit_conditions(await self._feature_flag_storage.get_change_number(), parsed['ff']['t'], self._DEFAULT_FEATURE_FLAG_TILL) \ + and self._check_exit_conditions(await self._rule_based_segment_storage.get_change_number(), parsed['rbs']['t'], self._DEFAULT_RB_SEGMENT_TILL): + return [] + + if not self._check_exit_conditions(await self._feature_flag_storage.get_change_number(), parsed['ff']['t'], self._DEFAULT_FEATURE_FLAG_TILL): + fetched_feature_flags = [splits.from_raw(feature_flag) for feature_flag in parsed['ff']['d']] + segment_list = await update_feature_flag_storage_async(self._feature_flag_storage, fetched_feature_flags, parsed['ff']['t']) + + if not self._check_exit_conditions(await self._rule_based_segment_storage.get_change_number(), parsed['rbs']['t'], self._DEFAULT_RB_SEGMENT_TILL): + fetched_rb_segments = [rule_based_segments.from_raw(rb_segment) for rb_segment in parsed['rbs']['d']] + segment_list.update(await update_rule_based_segment_storage_async(self._rule_based_segment_storage, fetched_rb_segments, parsed['rbs']['t'])) + + return segment_list + + except Exception as exc: + _LOGGER.debug('Exception: ', exc_info=True) + raise ValueError("Error reading feature flags from json.") from exc + + async def _read_feature_flags_from_json_file(self, filename): + """ + Parse a feature flags file and return a populated storage. + + :param filename: Path of the file containing feature flags + :type filename: str. - for split in to_delete: - self._split_storage.remove(split) + :return: Tuple: sanitized feature flag structure dict and till + :rtype: Tuple(Dict, int) + """ + try: + async with aiofiles.open(filename, 'r') as flo: + parsed = json.loads(await flo.read()) + + # check if spec version is old + if parsed.get('splits'): + parsed = util.convert_to_new_spec(parsed) + + santitized = self._sanitize_json_elements(parsed) + santitized['ff']['d'] = self._sanitize_feature_flag_elements(santitized['ff']['d']) + santitized['rbs']['d'] = self._sanitize_rb_segment_elements(santitized['rbs']['d']) + return santitized + except Exception as exc: + _LOGGER.debug('Exception: ', exc_info=True) + raise ValueError("Error parsing file %s. Make sure it's readable." % filename) from exc diff --git a/splitio/sync/synchronizer.py b/splitio/sync/synchronizer.py index 8c4fe13c..a6ca6214 100644 --- a/splitio/sync/synchronizer.py +++ b/splitio/sync/synchronizer.py @@ -3,23 +3,31 @@ import abc import logging import threading +import time +from collections import namedtuple -from splitio.api import APIException +from splitio.optional.loaders import asyncio +from splitio.api import APIException, APIUriException +from splitio.util.backoff import Backoff +from splitio.sync.split import _ON_DEMAND_FETCH_BACKOFF_BASE, _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES, _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT, LocalhostMode +SplitSyncResult = namedtuple('SplitSyncResult', ['success', 'error_code']) _LOGGER = logging.getLogger(__name__) +_SYNC_ALL_NO_RETRIES = -1 + class SplitSynchronizers(object): """SplitSynchronizers.""" - def __init__(self, split_sync, segment_sync, impressions_sync, events_sync, # pylint:disable=too-many-arguments - impressions_count_sync): + def __init__(self, feature_flag_sync, segment_sync, impressions_sync, events_sync, # pylint:disable=too-many-arguments + impressions_count_sync, telemetry_sync=None, unique_keys_sync = None, clear_filter_sync = None): """ Class constructor. - :param split_sync: sync for splits - :type split_sync: splitio.sync.split.SplitSynchronizer + :param feature_flag_sync: sync for feature flags + :type feature_flag_sync: splitio.sync.split.SplitSynchronizer :param segment_sync: sync for segments :type segment_sync: splitio.sync.segment.SegmentSynchronizer :param impressions_sync: sync for impressions @@ -29,16 +37,19 @@ def __init__(self, split_sync, segment_sync, impressions_sync, events_sync, # p :param impressions_count_sync: sync for impression_counts :type impressions_count_sync: splitio.sync.impression.ImpressionsCountSynchronizer """ - self._split_sync = split_sync + self._feature_flag_sync = feature_flag_sync self._segment_sync = segment_sync self._impressions_sync = impressions_sync self._events_sync = events_sync self._impressions_count_sync = impressions_count_sync + self._unique_keys_sync = unique_keys_sync + self._clear_filter_sync = clear_filter_sync + self._telemetry_sync = telemetry_sync @property def split_sync(self): """Return split synchonizer.""" - return self._split_sync + return self._feature_flag_sync @property def segment_sync(self): @@ -60,17 +71,31 @@ def impressions_count_sync(self): """Return impressions count synchonizer.""" return self._impressions_count_sync + @property + def unique_keys_sync(self): + """Return unique keys synchonizer.""" + return self._unique_keys_sync + + @property + def clear_filter_sync(self): + """Return clear filter synchonizer.""" + return self._clear_filter_sync + + @property + def telemetry_sync(self): + """Return clear filter synchonizer.""" + return self._telemetry_sync class SplitTasks(object): """SplitTasks.""" - def __init__(self, split_task, segment_task, impressions_task, events_task, # pylint:disable=too-many-arguments - impressions_count_task): + def __init__(self, feature_flag_task, segment_task, impressions_task, events_task, # pylint:disable=too-many-arguments + impressions_count_task, telemetry_task=None, unique_keys_task = None, clear_filter_task = None, internal_events_task=None): """ Class constructor. - :param split_task: sync for splits - :type split_task: splitio.tasks.split_sync.SplitSynchronizationTask + :param feature_flag_task: sync for feature_flags + :type feature_flag_task: splitio.tasks.split_sync.SplitSynchronizationTask :param segment_task: sync for segments :type segment_task: splitio.tasks.segment_sync.SegmentSynchronizationTask :param impressions_task: sync for impressions @@ -80,16 +105,20 @@ def __init__(self, split_task, segment_task, impressions_task, events_task, # p :param impressions_count_task: sync for impression_counts :type impressions_count_task: splitio.tasks.impressions_sync.ImpressionsCountSyncTask """ - self._split_task = split_task + self._feature_flag_task = feature_flag_task self._segment_task = segment_task self._impressions_task = impressions_task self._events_task = events_task self._impressions_count_task = impressions_count_task + self._unique_keys_task = unique_keys_task + self._clear_filter_task = clear_filter_task + self._telemetry_task = telemetry_task + self._internal_events_task = internal_events_task @property def split_task(self): - """Return split sync task.""" - return self._split_task + """Return feature_flag sync task.""" + return self._feature_flag_task @property def segment_task(self): @@ -111,6 +140,25 @@ def impressions_count_task(self): """Return impressions count sync task.""" return self._impressions_count_task + @property + def unique_keys_task(self): + """Return unique keys sync task.""" + return self._unique_keys_task + + @property + def clear_filter_task(self): + """Return clear filter sync task.""" + return self._clear_filter_task + + @property + def telemetry_task(self): + """Return clear filter sync task.""" + return self._telemetry_task + + @property + def internal_events_task(self): + """Return internal events task.""" + return self._internal_events_task class BaseSynchronizer(object, metaclass=abc.ABCMeta): """Synchronizer interface.""" @@ -130,7 +178,7 @@ def synchronize_segment(self, segment_name, till): @abc.abstractmethod def synchronize_splits(self, till): """ - Synchronize all splits. + Synchronize all feature flags. :param till: to fetch :type till: int @@ -139,17 +187,17 @@ def synchronize_splits(self, till): @abc.abstractmethod def sync_all(self): - """Synchronize all split data.""" + """Synchronize all feature flag data.""" pass @abc.abstractmethod def start_periodic_fetching(self): - """Start fetchers for splits and segments.""" + """Start fetchers for feature flags and segments.""" pass @abc.abstractmethod def stop_periodic_fetching(self): - """Stop fetchers for splits and segments.""" + """Stop fetchers for feature flags and segments.""" pass @abc.abstractmethod @@ -163,12 +211,12 @@ def stop_periodic_data_recording(self, blocking): pass @abc.abstractmethod - def kill_split(self, split_name, default_treatment, change_number): + def kill_split(self, feature_flag_name, default_treatment, change_number): """ - Kill a split locally. + Kill a feature flag locally. - :param split_name: name of the split to perform kill - :type split_name: str + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str :param default_treatment: name of the default treatment to return :type default_treatment: str :param change_number: change_number @@ -187,20 +235,136 @@ def shutdown(self, blocking): pass -class Synchronizer(BaseSynchronizer): +class SynchronizerInMemoryBase(BaseSynchronizer): """Synchronizer.""" def __init__(self, split_synchronizers, split_tasks): """ Class constructor. - :param split_synchronizers: syncs for performing synchronization of segments and splits + :param split_synchronizers: syncs for performing synchronization of segments and feature flags :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers :param split_tasks: tasks for starting/stopping tasks :type split_tasks: splitio.sync.synchronizer.SplitTasks """ + self._backoff = Backoff( + _ON_DEMAND_FETCH_BACKOFF_BASE, + _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT) self._split_synchronizers = split_synchronizers self._split_tasks = split_tasks + self._periodic_data_recording_tasks = [ + self._split_tasks.impressions_task, + self._split_tasks.events_task, + self._split_tasks.telemetry_task + ] + if self._split_tasks.impressions_count_task: + self._periodic_data_recording_tasks.append(self._split_tasks.impressions_count_task) + if self._split_tasks.unique_keys_task: + self._periodic_data_recording_tasks.append(self._split_tasks.unique_keys_task) + if self._split_tasks.clear_filter_task: + self._periodic_data_recording_tasks.append(self._split_tasks.clear_filter_task) + + @property + def split_sync(self): + return self._split_synchronizers.split_sync + + @property + def segment_storage(self): + return self._split_synchronizers.segment_sync._segment_storage + + def synchronize_segment(self, segment_name, till): + """ + Synchronize particular segment. + + :param segment_name: segment associated + :type segment_name: str + :param till: to fetch + :type till: int + """ + pass + + def synchronize_splits(self, till, sync_segments=True): + """ + Synchronize all feature flags. + + :param till: to fetch + :type till: int + + :returns: whether the synchronization was successful or not. + :rtype: bool + """ + pass + + def sync_all(self, max_retry_attempts=_SYNC_ALL_NO_RETRIES): + """ + Synchronize all feature flags. + + :param max_retry_attempts: apply max attempts if it set to absilute integer. + :type max_retry_attempts: int + """ + pass + + def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + pass + + def start_periodic_fetching(self): + """Start fetchers for feature flags and segments.""" + _LOGGER.debug('Starting periodic data fetching') + self._split_tasks.split_task.start() + self._split_tasks.segment_task.start() + + def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + pass + + def start_periodic_data_recording(self): + """Start recorders.""" + _LOGGER.debug('Starting periodic data recording') + for task in self._periodic_data_recording_tasks: + task.start() + + def stop_periodic_data_recording(self, blocking): + """ + Stop recorders. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + pass + + def kill_split(self, feature_flag_name, default_treatment, change_number): + """ + Kill a feature flag locally. + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + pass + + +class Synchronizer(SynchronizerInMemoryBase): + """Synchronizer.""" + + def __init__(self, split_synchronizers, split_tasks): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + SynchronizerInMemoryBase.__init__(self, split_synchronizers, split_tasks) def _synchronize_segments(self): _LOGGER.debug('Starting segments synchronization') @@ -221,9 +385,9 @@ def synchronize_segment(self, segment_name, till): _LOGGER.error('Failed to sync some segments.') return success - def synchronize_splits(self, till): + def synchronize_splits(self, till, sync_segments=True): """ - Synchronize all splits. + Synchronize all feature flags. :param till: to fetch :type till: int @@ -231,36 +395,67 @@ def synchronize_splits(self, till): :returns: whether the synchronization was successful or not. :rtype: bool """ - _LOGGER.debug('Starting splits synchronization') + _LOGGER.debug('Starting feature flags synchronization') try: - self._split_synchronizers.split_sync.synchronize_splits(till) - return True - except APIException: - _LOGGER.error('Failed syncing splits') + new_segments = [] + for segment in self._split_synchronizers.split_sync.synchronize_splits(till): + if not self._split_synchronizers.segment_sync.segment_exist_in_storage(segment): + new_segments.append(segment) + if sync_segments and len(new_segments) != 0: + _LOGGER.debug('Synching Segments: %s', ','.join(new_segments)) + success = self._split_synchronizers.segment_sync.synchronize_segments(new_segments, True) + if not success: + _LOGGER.error('Failed to schedule sync one or all segment(s) below.') + _LOGGER.error(','.join(new_segments)) + else: + _LOGGER.debug('Segment sync scheduled.') + return SplitSyncResult(True, 0) + except APIUriException as exc: + _LOGGER.error('Failed syncing feature flags due to long URI') _LOGGER.debug('Error: ', exc_info=True) - return False + return SplitSyncResult(False, exc._status_code) - def sync_all(self): - """Synchronize all split data.""" - attempts = 3 - while attempts > 0: + except APIException as exc: + _LOGGER.error('Failed syncing feature flags') + _LOGGER.debug('Error: ', exc_info=True) + return SplitSyncResult(False, exc._status_code) + + def sync_all(self, max_retry_attempts=_SYNC_ALL_NO_RETRIES): + """ + Synchronize all feature flags. + + :param max_retry_attempts: apply max attempts if it set to absilute integer. + :type max_retry_attempts: int + """ + retry_attempts = 0 + while True: try: - if not self.synchronize_splits(None): - attempts -= 1 - continue + sync_result = self.synchronize_splits(None, False) + if not sync_result.success and sync_result.error_code is not None and sync_result.error_code == 414: + _LOGGER.error("URI too long exception caught, aborting retries") + break + + if not sync_result.success: + raise Exception("feature flags sync failed") + + # Only retrying feature flags, since segments may trigger too many calls. - # Only retrying splits, since segments may trigger too many calls. if not self._synchronize_segments(): _LOGGER.warning('Segments failed to synchronize.') # All is good return except Exception as exc: # pylint:disable=broad-except - attempts -= 1 _LOGGER.error("Exception caught when trying to sync all data: %s", str(exc)) _LOGGER.debug('Error: ', exc_info=True) + if max_retry_attempts != _SYNC_ALL_NO_RETRIES: + retry_attempts += 1 + if retry_attempts > max_retry_attempts: + break + how_long = self._backoff.get() + time.sleep(how_long) - _LOGGER.error("Could not correctly synchronize splits and segments after 3 attempts.") + _LOGGER.error("Could not correctly synchronize feature flags and segments after %d attempts.", retry_attempts) def shutdown(self, blocking): """ @@ -274,25 +469,12 @@ def shutdown(self, blocking): self.stop_periodic_fetching() self.stop_periodic_data_recording(blocking) - def start_periodic_fetching(self): - """Start fetchers for splits and segments.""" - _LOGGER.debug('Starting periodic data fetching') - self._split_tasks.split_task.start() - self._split_tasks.segment_task.start() - def stop_periodic_fetching(self): - """Stop fetchers for splits and segments.""" + """Stop fetchers for feature flags and segments.""" _LOGGER.debug('Stopping periodic fetching') self._split_tasks.split_task.stop() self._split_tasks.segment_task.stop() - def start_periodic_data_recording(self): - """Start recorders.""" - _LOGGER.debug('Starting periodic data recording') - self._split_tasks.impressions_task.start() - self._split_tasks.events_task.start() - self._split_tasks.impressions_count_task.start() - def stop_periodic_data_recording(self, blocking): """ Stop recorders. @@ -301,87 +483,235 @@ def stop_periodic_data_recording(self, blocking): :type blocking: bool """ _LOGGER.debug('Stopping periodic data recording') + if self._split_tasks.internal_events_task: + self._split_tasks.internal_events_task.stop() + if blocking: events = [] - for task in [self._split_tasks.impressions_task, - self._split_tasks.events_task, - self._split_tasks.impressions_count_task]: - stop_event = threading.Event() - task.stop(stop_event) - events.append(stop_event) - if all(event.wait() for event in events): + for task in self._periodic_data_recording_tasks: + if task != self._split_tasks.telemetry_task: + stop_event = threading.Event() + task.stop(stop_event) + events.append(stop_event) + all(event.wait() for event in events) + telemetry_event = threading.Event() + self._split_tasks.telemetry_task.stop(telemetry_event) + if telemetry_event.wait(): _LOGGER.debug('all tasks finished successfully.') else: - self._split_tasks.impressions_task.stop() - self._split_tasks.events_task.stop() - self._split_tasks.impressions_count_task.stop() + for task in self._periodic_data_recording_tasks: + task.stop() - def kill_split(self, split_name, default_treatment, change_number): + def kill_split(self, feature_flag_name, default_treatment, change_number): """ - Kill a split locally. + Kill a feature flag locally. - :param split_name: name of the split to perform kill - :type split_name: str + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str :param default_treatment: name of the default treatment to return :type default_treatment: str :param change_number: change_number :type change_number: int """ - self._split_synchronizers.split_sync.kill_split(split_name, default_treatment, + self._split_synchronizers.split_sync.kill_split(feature_flag_name, default_treatment, change_number) - -class LocalhostSynchronizer(BaseSynchronizer): - """LocalhostSynchronizer.""" +class SynchronizerAsync(SynchronizerInMemoryBase): + """Synchronizer async.""" def __init__(self, split_synchronizers, split_tasks): """ Class constructor. - :param split_synchronizers: syncs for performing synchronization of segments and splits + :param split_synchronizers: syncs for performing synchronization of segments and feature flags :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers :param split_tasks: tasks for starting/stopping tasks :type split_tasks: splitio.sync.synchronizer.SplitTasks """ - self._split_synchronizers = split_synchronizers - self._split_tasks = split_tasks + SynchronizerInMemoryBase.__init__(self, split_synchronizers, split_tasks) + self._shutdown = False - def sync_all(self): - """Synchronize all split data.""" + async def _synchronize_segments(self): + _LOGGER.debug('Starting segments synchronization') + return await self._split_synchronizers.segment_sync.synchronize_segments() + + async def synchronize_segment(self, segment_name, till): + """ + Synchronize particular segment. + + :param segment_name: segment associated + :type segment_name: str + :param till: to fetch + :type till: int + """ + _LOGGER.debug('Synchronizing segment %s', segment_name) + success = await self._split_synchronizers.segment_sync.synchronize_segment(segment_name, till) + if not success: + _LOGGER.error('Failed to sync some segments.') + return success + + async def synchronize_splits(self, till, sync_segments=True): + """ + Synchronize all feature flags. + + :param till: to fetch + :type till: int + + :returns: whether the synchronization was successful or not. + :rtype: bool + """ + if self._shutdown: + return + + _LOGGER.debug('Starting feature flags synchronization') try: - self._split_synchronizers.split_sync.synchronize_splits(None) + new_segments = [] + for segment in await self._split_synchronizers.split_sync.synchronize_splits(till): + if not await self._split_synchronizers.segment_sync.segment_exist_in_storage(segment): + new_segments.append(segment) + if sync_segments and len(new_segments) != 0: + _LOGGER.debug('Synching Segments: %s', ','.join(new_segments)) + success = await self._split_synchronizers.segment_sync.synchronize_segments(new_segments, True) + if not success: + _LOGGER.error('Failed to schedule sync one or all segment(s) below.') + _LOGGER.error(','.join(new_segments)) + else: + _LOGGER.debug('Segment sync scheduled.') + return SplitSyncResult(True, 0) + except APIUriException as exc: + _LOGGER.error('Failed syncing feature flags due to long URI') + _LOGGER.debug('Error: ', exc_info=True) + return SplitSyncResult(False, exc._status_code) + except APIException as exc: - _LOGGER.error('Failed syncing splits') - raise APIException('Failed to sync splits') from exc + _LOGGER.error('Failed syncing feature flags') + _LOGGER.debug('Error: ', exc_info=True) + return SplitSyncResult(False, exc._status_code) - def start_periodic_fetching(self): - """Start fetchers for splits and segments.""" - _LOGGER.debug('Starting periodic data fetching') - self._split_tasks.split_task.start() + async def sync_all(self, max_retry_attempts=_SYNC_ALL_NO_RETRIES): + """ + Synchronize all feature flags. - def stop_periodic_fetching(self): - """Stop fetchers for splits and segments.""" + :param max_retry_attempts: apply max attempts if it set to absilute integer. + :type max_retry_attempts: int + """ + self._shutdown = False + retry_attempts = 0 + while not self._shutdown: + try: + sync_result = await self.synchronize_splits(None, False) + if not sync_result.success and sync_result.error_code is not None and sync_result.error_code == 414: + _LOGGER.error("URI too long exception caught, aborting retries") + break + + if not sync_result.success: + raise Exception("feature flags sync failed") + + # Only retrying feature flags, since segments may trigger too many calls. + + if not await self._synchronize_segments(): + _LOGGER.warning('Segments failed to synchronize.') + + # All is good + return + except Exception as exc: # pylint:disable=broad-except + _LOGGER.error("Exception caught when trying to sync all data: %s", str(exc)) + _LOGGER.debug('Error: ', exc_info=True) + if max_retry_attempts != _SYNC_ALL_NO_RETRIES: + retry_attempts += 1 + if retry_attempts > max_retry_attempts: + break + how_long = self._backoff.get() + if not self._shutdown: + await asyncio.sleep(how_long) + + _LOGGER.error("Could not correctly synchronize feature flags and segments after %d attempts.", retry_attempts) + + async def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.debug('Shutting down tasks.') + self._shutdown = True + await self._split_synchronizers.segment_sync.shutdown() + await self.stop_periodic_fetching() + await self.stop_periodic_data_recording(blocking) + + async def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" _LOGGER.debug('Stopping periodic fetching') - self._split_tasks.split_task.stop() + await self._split_tasks.split_task.stop() + await self._split_tasks.segment_task.stop() - def kill_split(self, split_name, default_treatment, change_number): - """Kill a split locally.""" - raise NotImplementedError() + async def stop_periodic_data_recording(self, blocking): + """ + Stop recorders. - def synchronize_splits(self, till): - """Synchronize all splits.""" - raise NotImplementedError() + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.debug('Stopping periodic data recording') + if self._split_tasks.internal_events_task: + await self._split_tasks.internal_events_task.stop() + + if blocking: + await self._stop_periodic_data_recording() + _LOGGER.debug('all tasks finished successfully.') + else: + asyncio.get_running_loop().create_task(self._stop_periodic_data_recording()) - def synchronize_segment(self, segment_name, till): - """Synchronize particular segment.""" - raise NotImplementedError() + async def _stop_periodic_data_recording(self): + """ + Stop recorders. - def start_periodic_data_recording(self): - """Start recorders.""" - pass + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + for task in self._periodic_data_recording_tasks: + await task.stop() - def stop_periodic_data_recording(self, blocking): - """Stop recorders.""" + async def kill_split(self, feature_flag_name, default_treatment, change_number): + """ + Kill a feature flag locally. + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + await self._split_synchronizers.split_sync.kill_split(feature_flag_name, default_treatment, + change_number) + +class RedisSynchronizerBase(BaseSynchronizer): + """Redis Synchronizer.""" + + def __init__(self, split_synchronizers, split_tasks): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + self._split_synchronizers = split_synchronizers + self._tasks = [] + if split_tasks.impressions_count_task is not None: + self._tasks.append(split_tasks.impressions_count_task) + if split_tasks.unique_keys_task is not None: + self._tasks.append(split_tasks.unique_keys_task) + if split_tasks.clear_filter_task is not None: + self._tasks.append(split_tasks.clear_filter_task) + + def sync_all(self): + """ + Not implemented + """ pass def shutdown(self, blocking): @@ -391,4 +721,483 @@ def shutdown(self, blocking): :param blocking:flag to wait until tasks are stopped :type blocking: bool """ - self.stop_periodic_fetching() + pass + + def start_periodic_data_recording(self): + """Start recorders.""" + _LOGGER.debug('Starting periodic data recording') + for task in self._tasks: + task.start() + + def stop_periodic_data_recording(self, blocking): + """ + Stop recorders. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + pass + + def kill_split(self, feature_flag_name, default_treatment, change_number): + """Kill a feature flag locally.""" + raise NotImplementedError() + + def synchronize_splits(self, till): + """Synchronize all feature flags.""" + raise NotImplementedError() + + def synchronize_segment(self, segment_name, till): + """Synchronize particular segment.""" + raise NotImplementedError() + + def start_periodic_fetching(self): + """Start fetchers for feature flags and segments.""" + raise NotImplementedError() + + def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + raise NotImplementedError() + + +class RedisSynchronizer(RedisSynchronizerBase): + """Redis Synchronizer.""" + + def __init__(self, split_synchronizers, split_tasks): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + RedisSynchronizerBase.__init__(self, split_synchronizers, split_tasks) + + def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.debug('Shutting down tasks.') + self.stop_periodic_data_recording(blocking) + + def stop_periodic_data_recording(self, blocking): + """ + Stop recorders. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.debug('Stopping periodic data recording') + if blocking: + events = [] + for task in self._tasks: + stop_event = threading.Event() + task.stop(stop_event) + events.append(stop_event) + if all(event.wait() for event in events): + _LOGGER.debug('all tasks finished successfully.') + else: + for task in self._tasks: + task.stop() + + +class RedisSynchronizerAsync(RedisSynchronizerBase): + """Redis Synchronizer.""" + + def __init__(self, split_synchronizers, split_tasks): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + RedisSynchronizerBase.__init__(self, split_synchronizers, split_tasks) + + async def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.debug('Shutting down tasks.') + await self.stop_periodic_data_recording(blocking) + + async def _stop_periodic_data_recording(self): + """ + Stop recorders. + """ + for task in self._tasks: + await task.stop() + + async def stop_periodic_data_recording(self, blocking): + """ + Stop recorders. + + :param blocking: flag to wait until tasks are stopped + :type blocking: bool + """ + _LOGGER.debug('Stopping periodic data recording') + if blocking: + await self._stop_periodic_data_recording() + _LOGGER.debug('all tasks finished successfully.') + else: + asyncio.get_running_loop().create_task(self._stop_periodic_data_recording) + + +class LocalhostSynchronizerBase(BaseSynchronizer): + """LocalhostSynchronizer base.""" + + def __init__(self, split_synchronizers, split_tasks, localhost_mode): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + self._split_synchronizers = split_synchronizers + self._split_tasks = split_tasks + self._localhost_mode = localhost_mode + self._backoff = Backoff( + _ON_DEMAND_FETCH_BACKOFF_BASE, + _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT) + + def sync_all(self, till=None): + """ + Synchronize all feature flags. + """ + # TODO: to be removed when legacy and yaml use BUR + pass + + def start_periodic_fetching(self): + """Start fetchers for feature flags and segments.""" + if self._split_tasks.split_task is not None: + _LOGGER.debug('Starting periodic data fetching') + self._split_tasks.split_task.start() + if self._split_tasks.segment_task is not None: + self._split_tasks.segment_task.start() + + def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + pass + + def kill_split(self, split_name, default_treatment, change_number): + """Kill a feature flag locally.""" + raise NotImplementedError() + + def synchronize_splits(self): + """Synchronize all feature flags.""" + pass + + def synchronize_segment(self, segment_name, till): + """Synchronize particular segment.""" + pass + + def start_periodic_data_recording(self): + """Start recorders.""" + pass + + def stop_periodic_data_recording(self, blocking): + """Stop recorders.""" + pass + + def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + pass + + +class LocalhostSynchronizer(LocalhostSynchronizerBase): + """LocalhostSynchronizer.""" + + def __init__(self, split_synchronizers, split_tasks, localhost_mode): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + LocalhostSynchronizerBase.__init__(self, split_synchronizers, split_tasks, localhost_mode) + + def sync_all(self, till=None): + """ + Synchronize all feature flags. + """ + # TODO: to be removed when legacy and yaml use BUR + if self._localhost_mode != LocalhostMode.JSON: + return self.synchronize_splits() + + self._backoff.reset() + remaining_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES + while remaining_attempts > 0: + remaining_attempts -= 1 + try: + return self.synchronize_splits() + except APIException as exc: + _LOGGER.error('Failed syncing all') + _LOGGER.error(str(exc)) + + how_long = self._backoff.get() + time.sleep(how_long) + + def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + _LOGGER.debug('Stopping periodic fetching') + if self._split_tasks.split_task is not None: + self._split_tasks.split_task.stop() + if self._split_tasks.segment_task is not None: + self._split_tasks.segment_task.stop() + if self._split_tasks.internal_events_task: + _LOGGER.debug('Stopping internal events notification') + self._split_tasks.internal_events_task.stop() + + def synchronize_splits(self): + """Synchronize all feature flags.""" + try: + new_segments = [] + for segment in self._split_synchronizers.split_sync.synchronize_splits(): + if not self._split_synchronizers.segment_sync.segment_exist_in_storage(segment): + new_segments.append(segment) + if len(new_segments) > 0: + _LOGGER.debug('Synching Segments: %s', ','.join(new_segments)) + success = self._split_synchronizers.segment_sync.synchronize_segments(new_segments) + if not success: + _LOGGER.error('Failed to schedule sync one or all segment(s) below.') + _LOGGER.error(','.join(new_segments)) + else: + _LOGGER.debug('Segment sync scheduled.') + return True + + except APIException as exc: + _LOGGER.error('Failed syncing feature flags') + raise APIException('Failed to sync feature flags') from exc + + def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + self.stop_periodic_fetching() + + +class LocalhostSynchronizerAsync(LocalhostSynchronizerBase): + """LocalhostSynchronizer Async.""" + + def __init__(self, split_synchronizers, split_tasks, localhost_mode): + """ + Class constructor. + + :param split_synchronizers: syncs for performing synchronization of segments and feature flags + :type split_synchronizers: splitio.sync.synchronizer.SplitSynchronizers + :param split_tasks: tasks for starting/stopping tasks + :type split_tasks: splitio.sync.synchronizer.SplitTasks + """ + LocalhostSynchronizerBase.__init__(self, split_synchronizers, split_tasks, localhost_mode) + + async def sync_all(self, till=None): + """ + Synchronize all feature flags. + """ + # TODO: to be removed when legacy and yaml use BUR + if self._localhost_mode != LocalhostMode.JSON: + return await self.synchronize_splits() + + self._backoff.reset() + remaining_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES + while remaining_attempts > 0: + remaining_attempts -= 1 + try: + return await self.synchronize_splits() + except APIException as exc: + _LOGGER.error('Failed syncing all') + _LOGGER.error(str(exc)) + + how_long = self._backoff.get() + await asyncio.sleep(how_long) + + async def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + _LOGGER.debug('Stopping periodic fetching') + if self._split_tasks.split_task is not None: + await self._split_tasks.split_task.stop() + if self._split_tasks.segment_task is not None: + await self._split_tasks.segment_task.stop() + if self._split_tasks.internal_events_task is not None: + _LOGGER.debug('Stopping internal events notification') + await self._split_tasks.internal_events_task.stop() + + async def synchronize_splits(self): + """Synchronize all feature flags.""" + try: + new_segments = [] + for segment in await self._split_synchronizers.split_sync.synchronize_splits(): + if not await self._split_synchronizers.segment_sync.segment_exist_in_storage(segment): + new_segments.append(segment) + if len(new_segments) > 0: + _LOGGER.debug('Synching Segments: %s', ','.join(new_segments)) + success = await self._split_synchronizers.segment_sync.synchronize_segments(new_segments) + if not success: + _LOGGER.error('Failed to schedule sync one or all segment(s) below.') + _LOGGER.error(','.join(new_segments)) + else: + _LOGGER.debug('Segment sync scheduled.') + return True + + except APIException as exc: + _LOGGER.error('Failed syncing feature flags') + raise APIException('Failed to sync feature flags') from exc + + async def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + await self.stop_periodic_fetching() + + +class PluggableSynchronizer(BaseSynchronizer): + """Plugable Synchronizer.""" + + def synchronize_segment(self, segment_name, till): + """ + Synchronize particular segment. + + :param segment_name: segment associated + :type segment_name: str + :param till: to fetch + :type till: int + """ + pass + + def synchronize_splits(self, till): + """ + Synchronize all feature flags. + + :param till: to fetch + :type till: int + """ + pass + + def sync_all(self): + """Synchronize all feature flag data.""" + pass + + def start_periodic_fetching(self): + """Start fetchers for feature flags and segments.""" + pass + + def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + pass + + def start_periodic_data_recording(self): + """Start recorders.""" + pass + + def stop_periodic_data_recording(self, blocking): + """Stop recorders.""" + pass + + def kill_split(self, feature_flag_name, default_treatment, change_number): + """ + Kill a feature_flag locally. + + :param feature_flag_name: name of the feature_flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + pass + + def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + pass + +class PluggableSynchronizerAsync(BaseSynchronizer): + """Plugable Synchronizer.""" + + async def synchronize_segment(self, segment_name, till): + """ + Synchronize particular segment. + + :param segment_name: segment associated + :type segment_name: str + :param till: to fetch + :type till: int + """ + pass + + async def synchronize_splits(self, till): + """ + Synchronize all feature flags. + + :param till: to fetch + :type till: int + """ + pass + + async def sync_all(self): + """Synchronize all split data.""" + pass + + async def start_periodic_fetching(self): + """Start fetchers for feature flags and segments.""" + pass + + async def stop_periodic_fetching(self): + """Stop fetchers for feature flags and segments.""" + pass + + async def start_periodic_data_recording(self): + """Start recorders.""" + pass + + async def stop_periodic_data_recording(self, blocking): + """Stop recorders.""" + pass + + async def kill_split(self, feature_flag_name, default_treatment, change_number): + """ + Kill a feature_flag locally. + + :param feature_flag_name: name of the feature flag to perform kill + :type feature_flag_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + pass + + async def shutdown(self, blocking): + """ + Stop tasks. + + :param blocking:flag to wait until tasks are stopped + :type blocking: bool + """ + pass diff --git a/splitio/sync/telemetry.py b/splitio/sync/telemetry.py new file mode 100644 index 00000000..38ce7da6 --- /dev/null +++ b/splitio/sync/telemetry.py @@ -0,0 +1,172 @@ +"""Telemetry Sync Class.""" +import abc + +class TelemetrySynchronizer(object): + """Telemetry synchronizer class.""" + + def __init__(self, telemetry_submitter): + """Initialize Telemetry sync class.""" + self._telemetry_submitter = telemetry_submitter + + def synchronize_config(self): + """synchronize initial config data class.""" + self._telemetry_submitter.synchronize_config() + + def synchronize_stats(self): + """synchronize runtime stats class.""" + self._telemetry_submitter.synchronize_stats() + + +class TelemetrySynchronizerAsync(object): + """Telemetry synchronizer class.""" + + def __init__(self, telemetry_submitter): + """Initialize Telemetry sync class.""" + self._telemetry_submitter = telemetry_submitter + + async def synchronize_config(self): + """synchronize initial config data class.""" + await self._telemetry_submitter.synchronize_config() + + async def synchronize_stats(self): + """synchronize runtime stats class.""" + await self._telemetry_submitter.synchronize_stats() + + +class TelemetrySubmitter(object, metaclass=abc.ABCMeta): + """Telemetry sumbitter interface.""" + + @abc.abstractmethod + def synchronize_config(self): + """synchronize initial config data classe.""" + + @abc.abstractmethod + def synchronize_stats(self): + """synchronize runtime stats class.""" + + +class InMemoryTelemetrySubmitter(TelemetrySubmitter): + """Telemetry sumbitter class.""" + + def __init__(self, telemetry_consumer, feature_flag_storage, segment_storage, telemetry_api): + """Initialize all producer classes.""" + self._telemetry_init_consumer = telemetry_consumer.get_telemetry_init_consumer() + self._telemetry_evaluation_consumer = telemetry_consumer.get_telemetry_evaluation_consumer() + self._telemetry_runtime_consumer = telemetry_consumer.get_telemetry_runtime_consumer() + self._telemetry_api = telemetry_api + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + + def synchronize_config(self): + """synchronize initial config data classe.""" + self._telemetry_api.record_init(self._telemetry_init_consumer.get_config_stats()) + + def synchronize_stats(self): + """synchronize runtime stats class.""" + self._telemetry_api.record_stats(self._build_stats()) + + def _build_stats(self): + """ + Format stats to Dict. + + :returns: formatted stats + :rtype: Dict + """ + merged_dict = { + 'spC': self._feature_flag_storage.get_splits_count(), + 'seC': self._segment_storage.get_segments_count(), + 'skC': self._segment_storage.get_segments_keys_count() + } + merged_dict.update(self._telemetry_runtime_consumer.pop_formatted_stats()) + merged_dict.update(self._telemetry_evaluation_consumer.pop_formatted_stats()) + return merged_dict + + +class InMemoryTelemetrySubmitterAsync(TelemetrySubmitter): + """Telemetry sumbitter async class.""" + + def __init__(self, telemetry_consumer, feature_flag_storage, segment_storage, telemetry_api): + """Initialize all producer classes.""" + self._telemetry_init_consumer = telemetry_consumer.get_telemetry_init_consumer() + self._telemetry_evaluation_consumer = telemetry_consumer.get_telemetry_evaluation_consumer() + self._telemetry_runtime_consumer = telemetry_consumer.get_telemetry_runtime_consumer() + self._telemetry_api = telemetry_api + self._feature_flag_storage = feature_flag_storage + self._segment_storage = segment_storage + + async def synchronize_config(self): + """synchronize initial config data classe.""" + await self._telemetry_api.record_init(await self._telemetry_init_consumer.get_config_stats()) + + async def synchronize_stats(self): + """synchronize runtime stats class.""" + await self._telemetry_api.record_stats(await self._build_stats()) + + async def _build_stats(self): + """ + Format stats to Dict. + + :returns: formatted stats + :rtype: Dict + """ + merged_dict = { + 'spC': await self._feature_flag_storage.get_splits_count(), + 'seC': await self._segment_storage.get_segments_count(), + 'skC': await self._segment_storage.get_segments_keys_count() + } + merged_dict.update(await self._telemetry_runtime_consumer.pop_formatted_stats()) + merged_dict.update(await self._telemetry_evaluation_consumer.pop_formatted_stats()) + return merged_dict + +class RedisTelemetrySubmitter(object): + """Telemetry sumbitter class.""" + + def __init__(self, telemetry_storage): + """Initialize all producer classes.""" + self._telemetry_storage = telemetry_storage + + def synchronize_config(self): + """synchronize initial config data classe.""" + self._telemetry_storage.push_config_stats() + + def synchronize_stats(self): + """No implementation.""" + pass + + +class RedisTelemetrySubmitterAsync(object): + """Telemetry sumbitter class.""" + + def __init__(self, telemetry_storage): + """Initialize all producer classes.""" + self._telemetry_storage = telemetry_storage + + async def synchronize_config(self): + """synchronize initial config data classe.""" + await self._telemetry_storage.push_config_stats() + + async def synchronize_stats(self): + """No implementation.""" + pass + +class LocalhostTelemetrySubmitter(object): + """Telemetry sumbitter class.""" + + def synchronize_config(self): + """No implementation.""" + pass + + def synchronize_stats(self): + """No implementation.""" + pass + +class LocalhostTelemetrySubmitterAsync(object): + """Telemetry sumbitter class.""" + + async def synchronize_config(self): + """No implementation.""" + pass + + async def synchronize_stats(self): + """No implementation.""" + pass diff --git a/splitio/sync/unique_keys.py b/splitio/sync/unique_keys.py new file mode 100644 index 00000000..b11a6084 --- /dev/null +++ b/splitio/sync/unique_keys.py @@ -0,0 +1,144 @@ +_UNIQUE_KEYS_MAX_BULK_SIZE = 5000 + +class UniqueKeysSynchronizerBase(object): + """Unique Keys Synchronizer base class.""" + + def __init__(self): + """ + Initialize Unique keys synchronizer instance + + :param uniqe_keys_tracker: instance of uniqe keys tracker + :type uniqe_keys_tracker: splitio.engine.uniqur_key_tracker.UniqueKeysTracker + """ + self._max_bulk_size = _UNIQUE_KEYS_MAX_BULK_SIZE + + def _split_cache_to_bulks(self, cache): + """ + Split the current unique keys dictionary into seperate dictionaries, + each with the size of max_bulk_size. Overflow the last feature_flag set() to new unique keys dictionary. + + :return: array of unique keys dictionaries + :rtype: [Dict{'feature_flag1': set(), 'feature_flag2': set(), .. }] + """ + bulks = [] + bulk = {} + total_size = 0 + for feature_flag in cache: + total_size += len(cache[feature_flag]) + if total_size > self._max_bulk_size: + keys_list = list(cache[feature_flag]) + chunk_list = self._chunks(keys_list) + if bulk != {}: + bulks.append(bulk) + for bulk_keys in chunk_list: + bulk[feature_flag] = set(bulk_keys) + bulks.append(bulk) + bulk = {} + else: + bulk[feature_flag] = cache[feature_flag] + if total_size != 0 and bulk != {}: + bulks.append(bulk) + + return bulks + + def _chunks(self, keys_list): + """ + Split array into chunks + """ + for i in range(0, len(keys_list), self._max_bulk_size): + yield keys_list[i:i + self._max_bulk_size] + + +class UniqueKeysSynchronizer(UniqueKeysSynchronizerBase): + """Unique Keys Synchronizer class.""" + + def __init__(self, impressions_sender_adapter, uniqe_keys_tracker): + """ + Initialize Unique keys synchronizer instance + + :param uniqe_keys_tracker: instance of uniqe keys tracker + :type uniqe_keys_tracker: splitio.engine.uniqur_key_tracker.UniqueKeysTracker + """ + UniqueKeysSynchronizerBase.__init__(self) + self._uniqe_keys_tracker = uniqe_keys_tracker + self._impressions_sender_adapter = impressions_sender_adapter + + def send_all(self): + """ + Flush the unique keys dictionary to split back end. + Limit each post to the max_bulk_size value. + + """ + cache, cache_size = self._uniqe_keys_tracker.get_cache_info_and_pop_all() + if cache_size <= self._max_bulk_size: + self._impressions_sender_adapter.record_unique_keys(cache) + else: + for bulk in self._split_cache_to_bulks(cache): + self._impressions_sender_adapter.record_unique_keys(bulk) + + +class UniqueKeysSynchronizerAsync(UniqueKeysSynchronizerBase): + """Unique Keys Synchronizer async class.""" + + def __init__(self, impressions_sender_adapter, uniqe_keys_tracker): + """ + Initialize Unique keys synchronizer instance + + :param uniqe_keys_tracker: instance of uniqe keys tracker + :type uniqe_keys_tracker: splitio.engine.uniqur_key_tracker.UniqueKeysTracker + """ + UniqueKeysSynchronizerBase.__init__(self) + self._uniqe_keys_tracker = uniqe_keys_tracker + self._impressions_sender_adapter = impressions_sender_adapter + + async def send_all(self): + """ + Flush the unique keys dictionary to split back end. + Limit each post to the max_bulk_size value. + + """ + cache, cache_size = await self._uniqe_keys_tracker.get_cache_info_and_pop_all() + if cache_size <= self._max_bulk_size: + await self._impressions_sender_adapter.record_unique_keys(cache) + else: + for bulk in self._split_cache_to_bulks(cache): + await self._impressions_sender_adapter.record_unique_keys(bulk) + + +class ClearFilterSynchronizer(object): + """Clear filter class.""" + + def __init__(self, unique_keys_tracker): + """ + Initialize Unique keys synchronizer instance + + :param uniqe_keys_tracker: instance of uniqe keys tracker + :type uniqe_keys_tracker: splitio.engine.uniqur_key_tracker.UniqueKeysTracker + """ + self._unique_keys_tracker = unique_keys_tracker + + def clear_all(self): + """ + Clear the bloom filter cache + + """ + self._unique_keys_tracker.clear_filter() + +class ClearFilterSynchronizerAsync(object): + """Clear filter async class.""" + + def __init__(self, unique_keys_tracker): + """ + Initialize Unique keys synchronizer instance + + :param uniqe_keys_tracker: instance of uniqe keys tracker + :type uniqe_keys_tracker: splitio.engine.uniqur_key_tracker.UniqueKeysTracker + """ + self._unique_keys_tracker = unique_keys_tracker + + async def clear_all(self): + """ + Clear the bloom filter cache + + """ + await self._unique_keys_tracker.clear_filter() diff --git a/splitio/sync/util.py b/splitio/sync/util.py new file mode 100644 index 00000000..cd32d2c2 --- /dev/null +++ b/splitio/sync/util.py @@ -0,0 +1,68 @@ +import hashlib +import logging + +_LOGGER = logging.getLogger(__name__) + +def _get_sha(fetched): + """ + Return sha256 of given string. + + :param fetched: string variable + :type fetched: str + + :return: hex representation of sha256 + :rtype: str + """ + return hashlib.sha256(fetched.encode()).hexdigest() + +def _sanitize_object_element(object, object_name, element_name, default_value, lower_value=None, upper_value=None, in_list=None, not_in_list=None): + """ + Sanitize specific object element. + + :param object: split or segment dict object + :type object: Dict + :param element_name: element name + :type element_name: str + :param default_value: element default value + :type default_value: any + :param lower_value: Optional, element lower value limit + :type lower_value: any + :param upper_value: Optional, element upper value limit + :type upper_value: any + :param in_list: Optional, list of values expected in element + :type in_list: [any] + :param not_in_list: Optional, list of values not expected in element + :type not_in_list: [any] + + :return: sanitized object + :rtype: Dict + """ + if element_name not in object or object[element_name] is None: + object[element_name] = default_value + _LOGGER.debug("Sanitized element [%s] to '%s' in %s: %s.", element_name, default_value, object_name, object['name']) + if lower_value is not None and upper_value is not None: + if object[element_name] < lower_value or object[element_name] > upper_value: + object[element_name] = default_value + _LOGGER.debug("Sanitized element [%s] to '%s' in %s: %s.", element_name, default_value, object_name, object['name']) + elif lower_value is not None: + if object[element_name] < lower_value: + object[element_name] = default_value + _LOGGER.debug("Sanitized element [%s] to '%s' in %s: %s.", element_name, default_value, object_name, object['name']) + elif upper_value is not None: + if object[element_name] > upper_value: + object[element_name] = default_value + _LOGGER.debug("Sanitized element [%s] to '%s' in %s: %s.", element_name, default_value, object_name, object['name']) + if in_list is not None: + if object[element_name] not in in_list: + object[element_name] = default_value + _LOGGER.debug("Sanitized element [%s] to '%s' in %s: %s.", element_name, default_value, object_name, object['name']) + if not_in_list is not None: + if object[element_name] in not_in_list: + object[element_name] = default_value + _LOGGER.debug("Sanitized element [%s] to '%s' in %s: %s.", element_name, default_value, object_name, object['name']) + + return object + +def convert_to_new_spec(body): + return {"ff": {"d": body["splits"], "s": body["since"], "t": body["till"]}, + "rbs": {"d": [], "s": -1, "t": -1}} diff --git a/splitio/tasks/events_sync.py b/splitio/tasks/events_sync.py index bddcfd2c..a9b9f255 100644 --- a/splitio/tasks/events_sync.py +++ b/splitio/tasks/events_sync.py @@ -2,13 +2,39 @@ import logging from splitio.tasks import BaseSynchronizationTask -from splitio.tasks.util.asynctask import AsyncTask +from splitio.tasks.util.asynctask import AsyncTask, AsyncTaskAsync _LOGGER = logging.getLogger(__name__) -class EventsSyncTask(BaseSynchronizationTask): +class EventsSyncTaskBase(BaseSynchronizationTask): + """Events synchronization task base uses an asynctask.AsyncTask to send events.""" + + def start(self): + """Start executing the events synchronization task.""" + self._task.start() + + def stop(self, event=None): + """Stop executing the events synchronization task.""" + pass + + def flush(self): + """Flush events in storage.""" + _LOGGER.debug('Forcing flush execution for events') + self._task.force_execution() + + def is_running(self): + """ + Return whether the task is running or not. + + :return: True if the task is running. False otherwise. + :rtype: bool + """ + return self._task.running() + + +class EventsSyncTask(EventsSyncTaskBase): """Events synchronization task uses an asynctask.AsyncTask to send events.""" def __init__(self, synchronize_events, period): @@ -24,24 +50,27 @@ def __init__(self, synchronize_events, period): self._period = period self._task = AsyncTask(synchronize_events, self._period, on_stop=synchronize_events) - def start(self): - """Start executing the events synchronization task.""" - self._task.start() - def stop(self, event=None): """Stop executing the events synchronization task.""" self._task.stop(event) - def flush(self): - """Flush events in storage.""" - _LOGGER.debug('Forcing flush execution for events') - self._task.force_execution() - def is_running(self): +class EventsSyncTaskAsync(EventsSyncTaskBase): + """Events synchronization task uses an asynctask.AsyncTaskAsync to send events.""" + + def __init__(self, synchronize_events, period): """ - Return whether the task is running or not. + Class constructor. + + :param synchronize_events: Events Api object to send data to the backend + :type synchronize_events: splitio.api.events.EventsAPIAsync + :param period: How many seconds to wait between subsequent event pushes to the BE. + :type period: int - :return: True if the task is running. False otherwise. - :rtype: bool """ - return self._task.running() + self._period = period + self._task = AsyncTaskAsync(synchronize_events, self._period, on_stop=synchronize_events) + + async def stop(self, event=None): + """Stop executing the events synchronization task.""" + await self._task.stop(True) diff --git a/splitio/tasks/impressions_sync.py b/splitio/tasks/impressions_sync.py index bfcc8993..195bdbdf 100644 --- a/splitio/tasks/impressions_sync.py +++ b/splitio/tasks/impressions_sync.py @@ -2,13 +2,39 @@ import logging from splitio.tasks import BaseSynchronizationTask -from splitio.tasks.util.asynctask import AsyncTask +from splitio.tasks.util.asynctask import AsyncTask, AsyncTaskAsync _LOGGER = logging.getLogger(__name__) -class ImpressionsSyncTask(BaseSynchronizationTask): +class ImpressionsSyncTaskBase(BaseSynchronizationTask): + """Impressions synchronization task uses an asynctask.AsyncTask to send impressions.""" + + def start(self): + """Start executing the impressions synchronization task.""" + self._task.start() + + def stop(self, event=None): + """Stop executing the impressions synchronization task.""" + pass + + def is_running(self): + """ + Return whether the task is running or not. + + :return: True if the task is running. False otherwise. + :rtype: bool + """ + return self._task.running() + + def flush(self): + """Flush impressions in storage.""" + _LOGGER.debug('Forcing flush execution for impressions') + self._task.force_execution() + + +class ImpressionsSyncTask(ImpressionsSyncTaskBase): """Impressions synchronization task uses an asynctask.AsyncTask to send impressions.""" def __init__(self, synchronize_impressions, period): @@ -25,13 +51,45 @@ def __init__(self, synchronize_impressions, period): self._task = AsyncTask(synchronize_impressions, self._period, on_stop=synchronize_impressions) + def stop(self, event=None): + """Stop executing the impressions synchronization task.""" + self._task.stop(event) + + +class ImpressionsSyncTaskAsync(ImpressionsSyncTaskBase): + """Impressions synchronization task uses an asynctask.AsyncTask to send impressions.""" + + def __init__(self, synchronize_impressions, period): + """ + Class constructor. + + :param synchronize_impressions: sender + :type synchronize_impressions: func + :param period: How many seconds to wait between subsequent impressions pushes to the BE. + :type period: int + + """ + self._period = period + self._task = AsyncTaskAsync(synchronize_impressions, self._period, + on_stop=synchronize_impressions) + + async def stop(self, event=None): + """Stop executing the impressions synchronization task.""" + await self._task.stop(True) + + +class ImpressionsCountSyncTaskBase(BaseSynchronizationTask): + """Impressions synchronization task uses an asynctask.AsyncTask to send impressions.""" + + _PERIOD = 1800 # 30 * 60 # 30 minutes + def start(self): """Start executing the impressions synchronization task.""" self._task.start() def stop(self, event=None): """Stop executing the impressions synchronization task.""" - self._task.stop(event) + pass def is_running(self): """ @@ -44,15 +102,12 @@ def is_running(self): def flush(self): """Flush impressions in storage.""" - _LOGGER.debug('Forcing flush execution for impressions') self._task.force_execution() -class ImpressionsCountSyncTask(BaseSynchronizationTask): +class ImpressionsCountSyncTask(ImpressionsCountSyncTaskBase): """Impressions synchronization task uses an asynctask.AsyncTask to send impressions.""" - _PERIOD = 1800 # 30 * 60 # 30 minutes - def __init__(self, synchronize_counters): """ Class constructor. @@ -63,23 +118,24 @@ def __init__(self, synchronize_counters): """ self._task = AsyncTask(synchronize_counters, self._PERIOD, on_stop=synchronize_counters) - def start(self): - """Start executing the impressions synchronization task.""" - self._task.start() - def stop(self, event=None): """Stop executing the impressions synchronization task.""" self._task.stop(event) - def is_running(self): + +class ImpressionsCountSyncTaskAsync(ImpressionsCountSyncTaskBase): + """Impressions synchronization task uses an asynctask.AsyncTask to send impressions.""" + + def __init__(self, synchronize_counters): """ - Return whether the task is running or not. + Class constructor. + + :param synchronize_counters: Handler + :type synchronize_counters: func - :return: True if the task is running. False otherwise. - :rtype: bool """ - return self._task.running() + self._task = AsyncTaskAsync(synchronize_counters, self._PERIOD, on_stop=synchronize_counters) - def flush(self): - """Flush impressions in storage.""" - self._task.force_execution() + async def stop(self): + """Stop executing the impressions synchronization task.""" + await self._task.stop(True) diff --git a/splitio/tasks/segment_sync.py b/splitio/tasks/segment_sync.py index 5297ce9f..55238634 100644 --- a/splitio/tasks/segment_sync.py +++ b/splitio/tasks/segment_sync.py @@ -8,7 +8,28 @@ _LOGGER = logging.getLogger(__name__) -class SegmentSynchronizationTask(BaseSynchronizationTask): +class SegmentSynchronizationTaskBase(BaseSynchronizationTask): + """Segment Syncrhonization base class.""" + + def start(self): + """Start segment synchronization.""" + self._task.start() + + def stop(self, event=None): + """Stop segment synchronization.""" + pass + + def is_running(self): + """ + Return whether the task is running or not. + + :return: True if the task is running. False otherwise. + :rtype: bool + """ + return self._task.running() + + +class SegmentSynchronizationTask(SegmentSynchronizationTaskBase): """Segment Syncrhonization class.""" def __init__(self, synchronize_segments, period): @@ -21,19 +42,24 @@ def __init__(self, synchronize_segments, period): """ self._task = asynctask.AsyncTask(synchronize_segments, period, on_init=None) - def start(self): - """Start segment synchronization.""" - self._task.start() - def stop(self, event=None): """Stop segment synchronization.""" self._task.stop(event) - def is_running(self): + +class SegmentSynchronizationTaskAsync(SegmentSynchronizationTaskBase): + """Segment Syncrhonization async class.""" + + def __init__(self, synchronize_segments, period): """ - Return whether the task is running or not. + Clas constructor. + + :param synchronize_segments: handler for syncing segments + :type synchronize_segments: func - :return: True if the task is running. False otherwise. - :rtype: bool """ - return self._task.running() + self._task = asynctask.AsyncTaskAsync(synchronize_segments, period, on_init=None) + + async def stop(self): + """Stop segment synchronization.""" + await self._task.stop(True) diff --git a/splitio/tasks/split_sync.py b/splitio/tasks/split_sync.py index 93aae875..0752bdbc 100644 --- a/splitio/tasks/split_sync.py +++ b/splitio/tasks/split_sync.py @@ -2,14 +2,36 @@ import logging from splitio.tasks import BaseSynchronizationTask -from splitio.tasks.util.asynctask import AsyncTask +from splitio.tasks.util.asynctask import AsyncTask, AsyncTaskAsync _LOGGER = logging.getLogger(__name__) -class SplitSynchronizationTask(BaseSynchronizationTask): +class SplitSynchronizationTaskBase(BaseSynchronizationTask): """Split Synchronization task class.""" + + def start(self): + """Start the task.""" + self._task.start() + + def stop(self, event=None): + """Stop the task. Accept an optional event to set when the task has finished.""" + pass + + def is_running(self): + """ + Return whether the task is running. + + :return: True if the task is running. False otherwise. + :rtype bool + """ + return self._task.running() + + +class SplitSynchronizationTask(SplitSynchronizationTaskBase): + """Split Synchronization task class.""" + def __init__(self, synchronize_splits, period): """ Class constructor. @@ -22,19 +44,26 @@ def __init__(self, synchronize_splits, period): self._period = period self._task = AsyncTask(synchronize_splits, period, on_init=None) - def start(self): - """Start the task.""" - self._task.start() - def stop(self, event=None): """Stop the task. Accept an optional event to set when the task has finished.""" self._task.stop(event) - def is_running(self): + +class SplitSynchronizationTaskAsync(SplitSynchronizationTaskBase): + """Split Synchronization async task class.""" + + def __init__(self, synchronize_splits, period): """ - Return whether the task is running. + Class constructor. - :return: True if the task is running. False otherwise. - :rtype bool + :param synchronize_splits: Handler + :type synchronize_splits: func + :param period: Period of task + :type period: int """ - return self._task.running() + self._period = period + self._task = AsyncTaskAsync(synchronize_splits, period, on_init=None) + + async def stop(self, event=None): + """Stop the task. Accept an optional event to set when the task has finished.""" + await self._task.stop(True) diff --git a/splitio/tasks/telemetry_sync.py b/splitio/tasks/telemetry_sync.py new file mode 100644 index 00000000..8545530c --- /dev/null +++ b/splitio/tasks/telemetry_sync.py @@ -0,0 +1,74 @@ +"""Telemetry syncrhonization task.""" +import logging + +from splitio.tasks import BaseSynchronizationTask +from splitio.tasks.util.asynctask import AsyncTask, AsyncTaskAsync + +_LOGGER = logging.getLogger(__name__) + +class TelemetrySyncTaskBase(BaseSynchronizationTask): + """Telemetry synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def start(self): + """Start executing the telemetry synchronization task.""" + self._task.start() + + def stop(self, event=None): + """Stop executing the unique telemetry synchronization task.""" + pass + + def is_running(self): + """ + Return whether the task is running or not. + + :return: True if the task is running. False otherwise. + :rtype: bool + """ + return self._task.running() + + def flush(self): + """Flush unique keys.""" + _LOGGER.debug('Forcing flush execution for telemetry') + self._task.force_execution() + + +class TelemetrySyncTask(TelemetrySyncTaskBase): + """Unique Telemetry task uses an asynctask.AsyncTask to send MTKs.""" + + def __init__(self, synchronize_telemetry, period): + """ + Class constructor. + + :param synchronize_telemetry: sender + :type synchronize_telemetry: func + :param period: How many seconds to wait between subsequent unique keys pushes to the BE. + :type period: int + """ + + self._task = AsyncTask(synchronize_telemetry, period, + on_stop=synchronize_telemetry) + + def stop(self, event=None): + """Stop executing the unique telemetry synchronization task.""" + self._task.stop(event) + + +class TelemetrySyncTaskAsync(TelemetrySyncTaskBase): + """Telemetry synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def __init__(self, synchronize_telemetry, period): + """ + Class constructor. + + :param synchronize_telemetry: sender + :type synchronize_telemetry: func + :param period: How many seconds to wait between subsequent unique keys pushes to the BE. + :type period: int + """ + + self._task = AsyncTaskAsync(synchronize_telemetry, period, + on_stop=synchronize_telemetry) + + async def stop(self): + """Stop executing the unique telemetry synchronization task.""" + await self._task.stop(True) diff --git a/splitio/tasks/unique_keys_sync.py b/splitio/tasks/unique_keys_sync.py new file mode 100644 index 00000000..9ba81253 --- /dev/null +++ b/splitio/tasks/unique_keys_sync.py @@ -0,0 +1,137 @@ +"""Impressions syncrhonization task.""" +import logging + +from splitio.tasks import BaseSynchronizationTask +from splitio.tasks.util.asynctask import AsyncTask, AsyncTaskAsync + + +_LOGGER = logging.getLogger(__name__) +_UNIQUE_KEYS_SYNC_PERIOD = 15 * 60 # 15 minutes +_CLEAR_FILTER_SYNC_PERIOD = 60 * 60 * 24 # 24 hours + + +class UniqueKeysSyncTaskBase(BaseSynchronizationTask): + """Unique Keys synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def start(self): + """Start executing the unique keys synchronization task.""" + self._task.start() + + def stop(self, event=None): + """Stop executing the unique keys synchronization task.""" + pass + + def is_running(self): + """ + Return whether the task is running or not. + + :return: True if the task is running. False otherwise. + :rtype: bool + """ + return self._task.running() + + def flush(self): + """Flush unique keys.""" + _LOGGER.debug('Forcing flush execution for unique keys') + self._task.force_execution() + + +class UniqueKeysSyncTask(UniqueKeysSyncTaskBase): + """Unique Keys synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def __init__(self, synchronize_unique_keys, period = _UNIQUE_KEYS_SYNC_PERIOD): + """ + Class constructor. + + :param synchronize_unique_keys: sender + :type synchronize_unique_keys: func + :param period: How many seconds to wait between subsequent unique keys pushes to the BE. + :type period: int + """ + self._task = AsyncTask(synchronize_unique_keys, period, + on_stop=synchronize_unique_keys) + + def stop(self, event=None): + """Stop executing the unique keys synchronization task.""" + self._task.stop(event) + + +class UniqueKeysSyncTaskAsync(UniqueKeysSyncTaskBase): + """Unique Keys synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def __init__(self, synchronize_unique_keys, period = _UNIQUE_KEYS_SYNC_PERIOD): + """ + Class constructor. + + :param synchronize_unique_keys: sender + :type synchronize_unique_keys: func + :param period: How many seconds to wait between subsequent unique keys pushes to the BE. + :type period: int + """ + self._task = AsyncTaskAsync(synchronize_unique_keys, period, + on_stop=synchronize_unique_keys) + + async def stop(self): + """Stop executing the unique keys synchronization task.""" + await self._task.stop(True) + + +class ClearFilterSyncTaskBase(BaseSynchronizationTask): + """Unique Keys synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def start(self): + """Start executing the unique keys synchronization task.""" + self._task.start() + + def stop(self, event=None): + """Stop executing the unique keys synchronization task.""" + pass + + def is_running(self): + """ + Return whether the task is running or not. + + :return: True if the task is running. False otherwise. + :rtype: bool + """ + return self._task.running() + + +class ClearFilterSyncTask(ClearFilterSyncTaskBase): + """Unique Keys synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def __init__(self, clear_filter, period = _CLEAR_FILTER_SYNC_PERIOD): + """ + Class constructor. + + :param synchronize_unique_keys: sender + :type synchronize_unique_keys: func + :param period: How many seconds to wait between subsequent clearing of bloom filter + :type period: int + """ + self._task = AsyncTask(clear_filter, period, + on_stop=clear_filter) + + def stop(self, event=None): + """Stop executing the unique keys synchronization task.""" + self._task.stop(event) + + +class ClearFilterSyncTaskAsync(ClearFilterSyncTaskBase): + """Unique Keys synchronization task uses an asynctask.AsyncTask to send MTKs.""" + + def __init__(self, clear_filter, period = _CLEAR_FILTER_SYNC_PERIOD): + """ + Class constructor. + + :param synchronize_unique_keys: sender + :type synchronize_unique_keys: func + :param period: How many seconds to wait between subsequent clearing of bloom filter + :type period: int + """ + self._task = AsyncTaskAsync(clear_filter, period, + on_stop=clear_filter) + + async def stop(self): + """Stop executing the unique keys synchronization task.""" + await self._task.stop(True) diff --git a/splitio/tasks/util/asynctask.py b/splitio/tasks/util/asynctask.py index 63a9f3fc..a772b2d7 100644 --- a/splitio/tasks/util/asynctask.py +++ b/splitio/tasks/util/asynctask.py @@ -2,14 +2,13 @@ import threading import logging import queue - +from splitio.optional.loaders import asyncio __TASK_STOP__ = 0 __TASK_FORCE_RUN__ = 1 _LOGGER = logging.getLogger(__name__) - def _safe_run(func): """ Execute a function wrapped in a try-except block. @@ -30,6 +29,26 @@ def _safe_run(func): _LOGGER.debug('Original traceback:', exc_info=True) return False +async def _safe_run_async(func): + """ + Execute a function wrapped in a try-except block. + + If anything goes wrong returns false instead of propagating the exception. + + :param func: Function to be executed, receives no arguments and it's return + value is ignored. + """ + try: + await func() + return True + except Exception: # pylint: disable=broad-except + # Catch any exception that might happen to avoid the periodic task + # from ending and allowing for a recovery, as well as preventing + # an exception from propagating and breaking the main thread + _LOGGER.error('Something went wrong when running passed function.') + _LOGGER.debug('Original traceback:', exc_info=True) + return False + class AsyncTask(object): # pylint: disable=too-many-instance-attributes """ @@ -94,7 +113,7 @@ def _execution_wrapper(self): _LOGGER.debug("Force execution signal received. Running now") if not _safe_run(self._main): _LOGGER.error("An error occurred when executing the task. " - "Retrying after perio expires") + "Retrying after period expires") continue except queue.Empty: # If no message was received, the timeout has expired @@ -104,7 +123,7 @@ def _execution_wrapper(self): if not _safe_run(self._main): _LOGGER.error( "An error occurred when executing the task. " - "Retrying after perio expires" + "Retrying after period expires" ) finally: self._cleanup() @@ -128,8 +147,7 @@ def start(self): # Start execution self._thread = threading.Thread(target=self._execution_wrapper, - name='AsyncTask::' + getattr(self._main, '__name__', 'N/S')) - self._thread.setDaemon(True) + name='AsyncTask::' + getattr(self._main, '__name__', 'N/S'), daemon=True) try: self._thread.start() @@ -167,3 +185,136 @@ def force_execution(self): def running(self): """Return whether the task is running or not.""" return self._running + + +class AsyncTaskAsync(object): # pylint: disable=too-many-instance-attributes + """ + Asyncrhonous controllable task async class. + + This class creates is used to wrap around a function to treat it as a + periodic task. This task can be stopped, it's execution can be forced, and + it's status (whether it's running or not) can be obtained from the task + object. + It also allows for "on init" and "on stop" functions to be passed. + """ + + + def __init__(self, main, period, on_init=None, on_stop=None): + """ + Class constructor. + + :param main: Main function to be executed periodically + :type main: callable + :param period: How many seconds to wait between executions + :type period: int + :param on_init: Function to be executed ONCE before the main one + :type on_init: callable + :param on_stop: Function to be executed ONCE after the task has finished + :type on_stop: callable + """ + self._on_init = on_init + self._main = main + self._on_stop = on_stop + self._period = period + self._messages = asyncio.Queue() + self._running = False + self._completion_event = asyncio.Event() + self._sleep_task = None + + async def _execution_wrapper(self): + """ + Execute user defined function in separate thread. + + It will execute the "on init" hook is available. If an exception is + raised it will abort execution, otherwise it will enter an infinite + loop in which the main function is executed every seconds. + After stop has been called the "on stop" hook will be invoked if + available. + + All custom functions are run within a _safe_run() function which + prevents exceptions from being propagated. + """ + try: + if self._on_init is not None: + if not await _safe_run_async(self._on_init): + _LOGGER.error("Error running task initialization function, aborting execution") + self._running = False + return + self._running = True + + while self._running: + try: + msg = await asyncio.wait_for(self._messages.get(), timeout=self._period) + if msg == __TASK_STOP__: + _LOGGER.debug("Stop signal received. finishing task execution") + break + elif msg == __TASK_FORCE_RUN__: + _LOGGER.debug("Force execution signal received. Running now") + if not await _safe_run_async(self._main): + _LOGGER.error("An error occurred when executing the task. " + "Retrying after period expires") + continue + except asyncio.QueueEmpty: + # If no message was received, the timeout has expired + # and we're ready for a new execution + pass + except asyncio.CancelledError: + break + except asyncio.TimeoutError: + pass + + if not await _safe_run_async(self._main): + _LOGGER.error( + "An error occurred when executing the task. " + "Retrying after period expires" + ) + finally: + await self._cleanup() + + async def _cleanup(self): + """Execute on_stop callback, set event if needed, update status.""" + if self._on_stop is not None: + if not await _safe_run_async(self._on_stop): + _LOGGER.error("An error occurred when executing the task's OnStop hook. ") + + self._running = False + self._completion_event.set() + _LOGGER.debug("AsyncTask finished") + + def start(self): + """Start the async task.""" + if self._running: + _LOGGER.warning("Task is already running. Ignoring .start() call") + return + # Start execution + self._completion_event.clear() + asyncio.get_running_loop().create_task(self._execution_wrapper()) + + async def stop(self, wait_for_completion=False): + """ + Send a signal to the thread in order to stop it. If the task is not running do nothing. + + Optionally accept an event to be set upon task completion. + + :param event: Event to set when the task completes. + :type event: threading.Event + """ + if not self._running: + return + + # Queue is of infinite size, should not raise an exception + self._messages.put_nowait(__TASK_STOP__) + + if wait_for_completion: + await self._completion_event.wait() + + def force_execution(self): + """Force an execution of the task without waiting for the period to end.""" + if not self._running: + return + # Queue is of infinite size, should not raise an exception + self._messages.put_nowait(__TASK_FORCE_RUN__) + + def running(self): + """Return whether the task is running or not.""" + return self._running diff --git a/splitio/tasks/util/workerpool.py b/splitio/tasks/util/workerpool.py index 32957ee6..8d6c6e53 100644 --- a/splitio/tasks/util/workerpool.py +++ b/splitio/tasks/util/workerpool.py @@ -4,10 +4,10 @@ from threading import Thread, Event import queue +from splitio.optional.loaders import asyncio _LOGGER = logging.getLogger(__name__) - class WorkerPool(object): """Worker pool class to implement single producer/multiple consumer.""" @@ -27,7 +27,7 @@ def __init__(self, worker_count, worker_func): for i in range(0, worker_count) ] for thread in self._threads: - thread.setDaemon(True) + thread.daemon = True def start(self): """Start the workers.""" @@ -116,8 +116,7 @@ def wait_for_completion(self): def stop(self, event=None): """Stop all worker nodes.""" - async_stop = Thread(target=self._wait_workers_shutdown, args=(event,)) - async_stop.setDaemon(True) + async_stop = Thread(target=self._wait_workers_shutdown, args=(event,), daemon=True) async_stop.start() def _wait_workers_shutdown(self, event): @@ -135,3 +134,96 @@ def _wait_workers_shutdown(self, event): for worker_event in self._worker_events: worker_event.wait() event.set() + + +class WorkerPoolAsync(object): + """Worker pool async class to implement single producer/multiple consumer.""" + + _abort = object() + + def __init__(self, worker_count, worker_func): + """ + Class constructor. + + :param worker_count: Number of workers for the pool. + :type worker_func: Function to be executed by the workers whenever a messages is fetched. + """ + self._semaphore = asyncio.Semaphore(worker_count) + self._queue = asyncio.Queue() + self._handler = worker_func + self._aborted = False + + async def _schedule_work(self): + """wrap the message handler execution.""" + while True: + message = await self._queue.get() + if message == self._abort: + self._aborted = True + return + asyncio.get_running_loop().create_task(self._do_work(message)) + + async def _do_work(self, message): + """process a single message.""" + try: + await self._semaphore.acquire() # wait until "there's a free worker" + if self._aborted: # check in case the pool was shutdown while we were waiting for a worker + return + await self._handler(message._message) + except Exception: + _LOGGER.error("Something went wrong when processing message %s", message) + _LOGGER.debug('Original traceback: ', exc_info=True) + message._failed = True + message._complete.set() + self._semaphore.release() # signal worker is idle + + def start(self): + """Start the workers.""" + asyncio.get_running_loop().create_task(self._schedule_work()) + + async def submit_work(self, jobs): + """ + Add a new message to the work-queue. + + :param message: New message to add. + :type message: object. + """ + self.jobs = jobs + if len(jobs) == 1: + wrapped = TaskCompletionWraper(next(i for i in jobs)) + await self._queue.put(wrapped) + return wrapped + + tasks = [TaskCompletionWraper(job) for job in jobs] + for w in tasks: + await self._queue.put(w) + + return BatchCompletionWrapper(tasks) + + async def stop(self, event=None): + """abort all execution (except currently running handlers).""" + await self._queue.put(self._abort) + + +class TaskCompletionWraper: + """Task completion class""" + def __init__(self, message): + self._message = message + self._complete = asyncio.Event() + self._failed = False + + async def await_completion(self): + await self._complete.wait() + return not self._failed + + def _mark_as_complete(self): + self._complete.set() + + +class BatchCompletionWrapper: + """Batch completion class""" + def __init__(self, tasks): + self._tasks = tasks + + async def await_completion(self): + await asyncio.gather(*[task.await_completion() for task in self._tasks]) + return not any(task._failed for task in self._tasks) diff --git a/splitio/util/__init__.py b/splitio/util/__init__.py index 1f2f9e4a..e69de29b 100644 --- a/splitio/util/__init__.py +++ b/splitio/util/__init__.py @@ -1,24 +0,0 @@ -"""Utilities.""" -from datetime import datetime - - -EPOCH_DATETIME = datetime(1970, 1, 1) - -def utctime(): - """ - Return the utc time in nanoseconds. - - :returns: utc time in nanoseconds. - :rtype: float - """ - return (datetime.utcnow() - EPOCH_DATETIME).total_seconds() - - -def utctime_ms(): - """ - Return the utc time in milliseconds. - - :returns: utc time in milliseconds. - :rtype: int - """ - return int(utctime() * 1000) diff --git a/splitio/util/storage_helper.py b/splitio/util/storage_helper.py new file mode 100644 index 00000000..81fdef65 --- /dev/null +++ b/splitio/util/storage_helper.py @@ -0,0 +1,200 @@ +"""Storage Helper.""" +import logging +from splitio.models import splits +from splitio.models import rule_based_segments + +_LOGGER = logging.getLogger(__name__) + +def update_feature_flag_storage(feature_flag_storage, feature_flags, change_number, clear_storage=False): + """ + Update feature flag storage from given list of feature flags while checking the flag set logic + + :param feature_flag_storage: Feature flag storage instance + :type feature_flag_storage: splitio.storage.inmemory.InMemorySplitStorage + :param feature_flag: Feature flag instance to validate. + :type feature_flag: splitio.models.splits.Split + :param: last change number + :type: int + + :return: segments list from feature flags list + :rtype: list(str) + """ + segment_list = set() + to_add = [] + to_delete = [] + if clear_storage: + feature_flag_storage.clear() + + for feature_flag in feature_flags: + if feature_flag_storage.flag_set_filter.intersect(feature_flag.sets) and feature_flag.status == splits.Status.ACTIVE: + to_add.append(feature_flag) + segment_list.update(set(feature_flag.get_segment_names())) + else: + if feature_flag_storage.get(feature_flag.name) is not None: + to_delete.append(feature_flag.name) + + feature_flag_storage.update(to_add, to_delete, change_number) + return segment_list + +def update_rule_based_segment_storage(rule_based_segment_storage, rule_based_segments, change_number, clear_storage=False): + """ + Update rule based segment storage from given list of rule based segments + + :param rule_based_segment_storage: rule based segment storage instance + :type rule_based_segment_storage: splitio.storage.RuleBasedSegmentStorage + :param rule_based_segments: rule based segment instance to validate. + :type rule_based_segments: splitio.models.rule_based_segments.RuleBasedSegment + :param: last change number + :type: int + + :return: segments list from excluded segments list + :rtype: list(str) + """ + if clear_storage: + rule_based_segment_storage.clear() + + segment_list = set() + to_add = [] + to_delete = [] + for rule_based_segment in rule_based_segments: + if rule_based_segment.status == splits.Status.ACTIVE: + to_add.append(rule_based_segment) + segment_list.update(set(rule_based_segment.excluded.get_excluded_standard_segments())) + segment_list.update(rule_based_segment.get_condition_segment_names()) + else: + if rule_based_segment_storage.get(rule_based_segment.name) is not None: + to_delete.append(rule_based_segment.name) + + rule_based_segment_storage.update(to_add, to_delete, change_number) + return segment_list + +def get_standard_segment_names_in_rbs_storage(rule_based_segment_storage): + """ + Retrieve a list of all standard segments names. + + :return: Set of segment names. + :rtype: Set(str) + """ + segment_list = set() + for rb_segment in rule_based_segment_storage.get_segment_names(): + rb_segment_obj = rule_based_segment_storage.get(rb_segment) + segment_list.update(set(rb_segment_obj.excluded.get_excluded_standard_segments())) + segment_list.update(rb_segment_obj.get_condition_segment_names()) + + return segment_list + +async def update_feature_flag_storage_async(feature_flag_storage, feature_flags, change_number, clear_storage=False): + """ + Update feature flag storage from given list of feature flags while checking the flag set logic + + :param feature_flag_storage: Feature flag storage instance + :type feature_flag_storage: splitio.storage.inmemory.InMemorySplitStorage + :param feature_flag: Feature flag instance to validate. + :type feature_flag: splitio.models.splits.Split + :param: last change number + :type: int + + :return: segments list from feature flags list + :rtype: list(str) + """ + if clear_storage: + await feature_flag_storage.clear() + + segment_list = set() + to_add = [] + to_delete = [] + for feature_flag in feature_flags: + if feature_flag_storage.flag_set_filter.intersect(feature_flag.sets) and feature_flag.status == splits.Status.ACTIVE: + to_add.append(feature_flag) + segment_list.update(set(feature_flag.get_segment_names())) + else: + if await feature_flag_storage.get(feature_flag.name) is not None: + to_delete.append(feature_flag.name) + + await feature_flag_storage.update(to_add, to_delete, change_number) + return segment_list + +async def update_rule_based_segment_storage_async(rule_based_segment_storage, rule_based_segments, change_number, clear_storage=False): + """ + Update rule based segment storage from given list of rule based segments + + :param rule_based_segment_storage: rule based segment storage instance + :type rule_based_segment_storage: splitio.storage.RuleBasedSegmentStorage + :param rule_based_segments: rule based segment instance to validate. + :type rule_based_segments: splitio.models.rule_based_segments.RuleBasedSegment + :param: last change number + :type: int + + :return: segments list from excluded segments list + :rtype: list(str) + """ + if clear_storage: + await rule_based_segment_storage.clear() + + segment_list = set() + to_add = [] + to_delete = [] + for rule_based_segment in rule_based_segments: + if rule_based_segment.status == splits.Status.ACTIVE: + to_add.append(rule_based_segment) + segment_list.update(set(rule_based_segment.excluded.get_excluded_standard_segments())) + segment_list.update(rule_based_segment.get_condition_segment_names()) + else: + if await rule_based_segment_storage.get(rule_based_segment.name) is not None: + to_delete.append(rule_based_segment.name) + + await rule_based_segment_storage.update(to_add, to_delete, change_number) + return segment_list + +async def get_standard_segment_names_in_rbs_storage_async(rule_based_segment_storage): + """ + Retrieve a list of all standard segments names. + + :return: Set of segment names. + :rtype: Set(str) + """ + segment_list = set() + segment_names = await rule_based_segment_storage.get_segment_names() + for rb_segment in segment_names: + rb_segment_obj = await rule_based_segment_storage.get(rb_segment) + segment_list.update(set(rb_segment_obj.excluded.get_excluded_standard_segments())) + segment_list.update(rb_segment_obj.get_condition_segment_names()) + + return segment_list + +def get_valid_flag_sets(flag_sets, flag_set_filter): + """ + Check each flag set in given array, return it if exist in a given config flag set array, if config array is empty return all + + :param flag_sets: Flag sets array + :type flag_sets: list(str) + :param config_flag_sets: Config flag sets array + :type config_flag_sets: list(str) + + :return: array of flag sets + :rtype: list(str) + """ + sets_to_fetch = [] + for flag_set in flag_sets: + if not flag_set_filter.set_exist(flag_set) and flag_set_filter.should_filter: + _LOGGER.warning("Flag set %s is not part of the configured flag set list, ignoring the request." % (flag_set)) + continue + sets_to_fetch.append(flag_set) + + return sets_to_fetch + +def combine_valid_flag_sets(result_sets): + """ + Check each flag set in given array of sets, combine all flag sets in one unique set + + :param result_sets: Flag sets set + :type flag_sets: list(set) + + :return: flag sets set + :rtype: set + """ + to_return = set() + for result_set in result_sets: + if isinstance(result_set, set) and len(result_set) > 0: + to_return.update(result_set) + return to_return \ No newline at end of file diff --git a/splitio/util/time.py b/splitio/util/time.py new file mode 100644 index 00000000..62743327 --- /dev/null +++ b/splitio/util/time.py @@ -0,0 +1,33 @@ +"""Utilities.""" +from datetime import datetime +import time + +EPOCH_DATETIME = datetime(1970, 1, 1) + +def utctime(): + """ + Return the utc time in nanoseconds. + + :returns: utc time in nanoseconds. + :rtype: float + """ + return (datetime.utcnow() - EPOCH_DATETIME).total_seconds() + + +def utctime_ms(): + """ + Return the utc time in milliseconds. + + :returns: utc time in milliseconds. + :rtype: int + """ + return int(utctime() * 1000) + +def get_current_epoch_time_ms(): + """ + Get current epoch time in milliseconds + + :return: epoch time + :rtype: int + """ + return int(round(time.time() * 1000)) \ No newline at end of file diff --git a/splitio/version.py b/splitio/version.py index d0c18ecd..4f40eda2 100644 --- a/splitio/version.py +++ b/splitio/version.py @@ -1 +1 @@ -__version__ = '9.1.2' +__version__ = '10.6.0' \ No newline at end of file diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6bcf261f..175977a2 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -1,12 +1,14 @@ """Split API tests module.""" import pytest +import unittest.mock as mock from splitio.api import auth, client, APIException from splitio.client.util import get_metadata from splitio.client.config import DEFAULT_CONFIG from splitio.version import __version__ - +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync class AuthAPITests(object): """Auth API test cases.""" @@ -19,17 +21,20 @@ def test_auth(self, mocker): cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) sdk_metadata = get_metadata(cfg) - httpclient.get.return_value = client.HttpResponse(200, payload) - auth_api = auth.AuthAPI(httpclient, 'some_api_key', sdk_metadata) + httpclient.get.return_value = client.HttpResponse(200, payload, {}) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + auth_api = auth.AuthAPI(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) response = auth_api.authenticate() assert response.push_enabled == True assert response.token == token - + call_made = httpclient.get.mock_calls[0] # validate positional arguments - assert call_made[1] == ('auth', '/v2/auth', 'some_api_key') + assert call_made[1] == ('auth', 'v2/auth?s=1.3', 'some_api_key') # validate key-value args (headers) assert call_made[2]['extra_headers'] == { @@ -46,3 +51,58 @@ def raise_exception(*args, **kwargs): response = auth_api.authenticate() assert exc_info.type == APIException assert exc_info.value.message == 'some_message' + + +class AuthAPIAsyncTests(object): + """Auth async API test cases.""" + + @pytest.mark.asyncio + async def test_auth(self, mocker): + """Test auth API call.""" + self.token = "eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk56TTJNREk1TXpjMF9NVGd5TlRnMU1UZ3dOZz09X3NlZ21lbnRzXCI6W1wic3Vic2NyaWJlXCJdLFwiTnpNMk1ESTVNemMwX01UZ3lOVGcxTVRnd05nPT1fc3BsaXRzXCI6W1wic3Vic2NyaWJlXCJdLFwiY29udHJvbF9wcmlcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXSxcImNvbnRyb2xfc2VjXCI6W1wic3Vic2NyaWJlXCIsXCJjaGFubmVsLW1ldGFkYXRhOnB1Ymxpc2hlcnNcIl19IiwieC1hYmx5LWNsaWVudElkIjoiY2xpZW50SWQiLCJleHAiOjE2MDIwODgxMjcsImlhdCI6MTYwMjA4NDUyN30.5_MjWonhs6yoFhw44hNJm3H7_YMjXpSW105DwjjppqE" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + auth_api = auth.AuthAPIAsync(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + + self.verb = None + self.url = None + self.key = None + self.headers = None + async def get(verb, url, key, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + payload = '{{"pushEnabled": true, "token": "{token}"}}'.format(token=self.token) + return client.HttpResponse(200, payload, {}) + + httpclient.get = get + + response = await auth_api.authenticate() + assert response.push_enabled == True + assert response.token == self.token + + # validate positional arguments + assert self.verb == 'auth' + assert self.url == 'v2/auth?s=1.3' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + + httpclient.get = raise_exception + with pytest.raises(APIException) as exc_info: + response = await auth_api.authenticate() + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_events.py b/tests/api/test_events.py index bfc6177b..07fe9473 100644 --- a/tests/api/test_events.py +++ b/tests/api/test_events.py @@ -1,11 +1,15 @@ """Impressions API tests module.""" import pytest +import unittest.mock as mock + from splitio.api import events, client, APIException from splitio.models.events import Event from splitio.client.util import get_metadata from splitio.client.config import DEFAULT_CONFIG from splitio.version import __version__ +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync class EventsAPITests(object): @@ -26,17 +30,20 @@ class EventsAPITests(object): def test_post_events(self, mocker): """Test impressions posting API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.post.return_value = client.HttpResponse(200, '') + httpclient.post.return_value = client.HttpResponse(200, '', {}) cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) sdk_metadata = get_metadata(cfg) - events_api = events.EventsAPI(httpclient, 'some_api_key', sdk_metadata) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + events_api = events.EventsAPI(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) response = events_api.flush_events(self.events) call_made = httpclient.post.mock_calls[0] # validate positional arguments - assert call_made[1] == ('events', '/events/bulk', 'some_api_key') + assert call_made[1] == ('events', 'events/bulk', 'some_api_key') # validate key-value args (headers) assert call_made[2]['extra_headers'] == { @@ -60,17 +67,17 @@ def raise_exception(*args, **kwargs): def test_post_events_ip_address_disabled(self, mocker): """Test impressions posting API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.post.return_value = client.HttpResponse(200, '') + httpclient.post.return_value = client.HttpResponse(200, '', {}) cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': False}) sdk_metadata = get_metadata(cfg) - events_api = events.EventsAPI(httpclient, 'some_api_key', sdk_metadata) + events_api = events.EventsAPI(httpclient, 'some_api_key', sdk_metadata, mocker.Mock()) response = events_api.flush_events(self.events) call_made = httpclient.post.mock_calls[0] # validate positional arguments - assert call_made[1] == ('events', '/events/bulk', 'some_api_key') + assert call_made[1] == ('events', 'events/bulk', 'some_api_key') # validate key-value args (headers) assert call_made[2]['extra_headers'] == { @@ -79,3 +86,108 @@ def test_post_events_ip_address_disabled(self, mocker): # validate key-value args (body) assert call_made[2]['body'] == self.eventsExpected + + +class EventsAPIAsyncTests(object): + """Impressions Async API test cases.""" + events = [ + Event('k1', 'user', 'purchase', 12.50, 123456, None), + Event('k2', 'user', 'purchase', 12.50, 123456, None), + Event('k3', 'user', 'purchase', None, 123456, {"test": 1234}), + Event('k4', 'user', 'purchase', None, 123456, None) + ] + eventsExpected = [ + {'key': 'k1', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': 12.50, 'timestamp': 123456, 'properties': None}, + {'key': 'k2', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': 12.50, 'timestamp': 123456, 'properties': None}, + {'key': 'k3', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': None, 'timestamp': 123456, 'properties': {"test": 1234}}, + {'key': 'k4', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': None, 'timestamp': 123456, 'properties': None}, + ] + + @pytest.mark.asyncio + async def test_post_events(self, mocker): + """Test impressions posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + events_api = events.EventsAPIAsync(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await events_api.flush_events(self.events) + # validate positional arguments + assert self.verb == 'events' + assert self.url == 'events/bulk' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert self.body == self.eventsExpected + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post = raise_exception + with pytest.raises(APIException) as exc_info: + response = await events_api.flush_events(self.events) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + @pytest.mark.asyncio + async def test_post_events_ip_address_disabled(self, mocker): + """Test impressions posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + sdk_metadata = get_metadata(cfg) + events_api = events.EventsAPIAsync(httpclient, 'some_api_key', sdk_metadata, mocker.Mock()) + + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await events_api.flush_events(self.events) + + # validate positional arguments + assert self.verb == 'events' + assert self.url == 'events/bulk' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + } + + # validate key-value args (body) + assert self.body == self.eventsExpected diff --git a/tests/api/test_httpclient.py b/tests/api/test_httpclient.py index 694c9a22..837997aa 100644 --- a/tests/api/test_httpclient.py +++ b/tests/api/test_httpclient.py @@ -1,6 +1,13 @@ """HTTPClient test module.""" +from requests_kerberos import HTTPKerberosAuth +import pytest +import unittest.mock as mock +import requests +from splitio.client.config import AuthenticateScheme from splitio.api import client +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync class HttpClientTests(object): """Http Client test cases.""" @@ -9,14 +16,16 @@ def test_get(self, mocker): """Test HTTP GET verb requests.""" response_mock = mocker.Mock() response_mock.status_code = 200 + response_mock.headers = {} response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock mocker.patch('splitio.api.client.requests.get', new=get_mock) httpclient = client.HttpClient() - response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) call = mocker.call( - client.HttpClient.SDK_URL + '/test1', + client.SDK_URL + '/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, timeout=None @@ -26,9 +35,9 @@ def test_get(self, mocker): assert get_mock.mock_calls == [call] get_mock.reset_mock() - response = httpclient.get('events', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + response = httpclient.get('events', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) call = mocker.call( - client.HttpClient.EVENTS_URL + '/test1', + client.EVENTS_URL + '/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, timeout=None @@ -41,12 +50,14 @@ def test_get_custom_urls(self, mocker): """Test HTTP GET verb requests.""" response_mock = mocker.Mock() response_mock.status_code = 200 + response_mock.headers = {} response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock mocker.patch('splitio.api.client.requests.get', new=get_mock) httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') - response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) call = mocker.call( 'https://sdk.com/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, @@ -58,7 +69,7 @@ def test_get_custom_urls(self, mocker): assert response.body == 'ok' get_mock.reset_mock() - response = httpclient.get('events', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + response = httpclient.get('events', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) call = mocker.call( 'https://events.com/test1', headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, @@ -74,14 +85,16 @@ def test_post(self, mocker): """Test HTTP GET verb requests.""" response_mock = mocker.Mock() response_mock.status_code = 200 + response_mock.headers = {} response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock mocker.patch('splitio.api.client.requests.post', new=get_mock) httpclient = client.HttpClient() - response = httpclient.post('sdk', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) call = mocker.call( - client.HttpClient.SDK_URL + '/test1', + client.SDK_URL + '/test1', json={'p1': 'a'}, headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, @@ -92,9 +105,9 @@ def test_post(self, mocker): assert get_mock.mock_calls == [call] get_mock.reset_mock() - response = httpclient.post('events', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + response = httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) call = mocker.call( - client.HttpClient.EVENTS_URL + '/test1', + client.EVENTS_URL + '/test1', json={'p1': 'a'}, headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, params={'param1': 123}, @@ -108,12 +121,14 @@ def test_post_custom_urls(self, mocker): """Test HTTP GET verb requests.""" response_mock = mocker.Mock() response_mock.status_code = 200 + response_mock.headers = {} response_mock.text = 'ok' get_mock = mocker.Mock() get_mock.return_value = response_mock mocker.patch('splitio.api.client.requests.post', new=get_mock) httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') - response = httpclient.post('sdk', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) call = mocker.call( 'https://sdk.com' + '/test1', json={'p1': 'a'}, @@ -126,7 +141,7 @@ def test_post_custom_urls(self, mocker): assert get_mock.mock_calls == [call] get_mock.reset_mock() - response = httpclient.post('events', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + response = httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) call = mocker.call( 'https://events.com' + '/test1', json={'p1': 'a'}, @@ -137,3 +152,584 @@ def test_post_custom_urls(self, mocker): assert response.status_code == 200 assert response.body == 'ok' assert get_mock.mock_calls == [call] + + def test_telemetry(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.headers = {} + response_mock.text = 'ok' + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.api.client.requests.post', new=get_mock) + httpclient = client.HttpClient(timeout=1500, sdk_url='https://sdk.com', events_url='https://events.com') + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + + self.metric1 = None + self.cur_time = 0 + def record_successful_sync(metric_name, cur_time): + self.metric1 = metric_name + self.cur_time = cur_time + httpclient._telemetry_runtime_producer.record_successful_sync = record_successful_sync + + self.metric2 = None + self.elapsed = 0 + def record_sync_latency(metric_name, elapsed): + self.metric2 = metric_name + self.elapsed = elapsed + httpclient._telemetry_runtime_producer.record_sync_latency = record_sync_latency + + self.metric3 = None + self.status = 0 + def record_sync_error(metric_name, elapsed): + self.metric3 = metric_name + self.status = elapsed + httpclient._telemetry_runtime_producer.record_sync_error = record_sync_error + + httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert (self.metric2 == "metric") + assert (self.metric1 == "metric") + assert (self.cur_time > self.elapsed) + + response_mock.status_code = 400 + response_mock.headers = {} + response_mock.text = 'ok' + httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert (self.metric3 == "metric") + assert (self.status == 400) + + # testing get call + mocker.patch('splitio.api.client.requests.get', new=get_mock) + self.metric1 = None + self.cur_time = 0 + self.metric2 = None + self.elapsed = 0 + response_mock.status_code = 200 + httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert (self.metric2 == "metric") + assert (self.metric1 == "metric") + assert (self.cur_time > self.elapsed) + + self.metric3 = None + self.status = 0 + response_mock.status_code = 400 + httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert (self.metric3 == "metric") + assert (self.status == 400) + +class HttpClientKerberosTests(object): + """Http Client test cases.""" + + def test_authentication_scheme(self, mocker): + global turl + global theaders + global tparams + global ttimeout + global tjson + + turl = None + theaders = None + tparams = None + ttimeout = None + class get_mock(object): + def __init__(self, url, headers, params, timeout): + global turl + global theaders + global tparams + global ttimeout + turl = url + theaders = headers + tparams = params + ttimeout = timeout + + def __enter__(self): + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + return response_mock + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) + httpclient = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=[None, None]) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert turl == 'https://sdk.com/test1' + assert theaders == {'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'} + assert tparams == {'param1': 123} + assert ttimeout == None + assert response.status_code == 200 + assert response.body == 'ok' + + turl = None + theaders = None + tparams = None + ttimeout = None + httpclient = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=['bilal', 'split']) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert turl == 'https://sdk.com/test1' + assert theaders == {'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'} + assert tparams == {'param1': 123} + assert ttimeout == None + + assert response.status_code == 200 + assert response.body == 'ok' + + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.headers = {} + response_mock.text = 'ok' + + turl = None + theaders = None + tparams = None + ttimeout = None + tjson = None + class post_mock(object): + def __init__(self, url, params, headers, json, timeout): + global turl + global theaders + global tparams + global ttimeout + global tjson + turl = url + theaders = headers + tparams = params + ttimeout = timeout + tjson = json + + def __enter__(self): + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + return response_mock + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + mocker.patch('splitio.api.client.requests.Session.post', new=post_mock) + + httpclient = client.HttpClientKerberos(timeout=1500, sdk_url='https://sdk.com', events_url='https://events.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=[None, None]) + httpclient.set_telemetry_data("metric", mocker.Mock()) + + response = httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert turl == 'https://events.com/test1' + assert tjson == {'p1': 'a'} + assert theaders == {'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'} + assert tparams == {'param1': 123} + assert ttimeout == 1.5 + + assert response.status_code == 200 + assert response.body == 'ok' + + turl = None + theaders = None + tparams = None + ttimeout = None + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) + httpclient = client.HttpClientKerberos(timeout=1500, sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=['bilal', 'split']) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert turl == 'https://sdk.com/test1' + assert theaders == {'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'} + assert tparams == {'param1': 123} + assert ttimeout == 1.5 + + assert response.status_code == 200 + assert response.body == 'ok' + + # test auth settings + httpclient = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=['bilal', 'split']) + httpclient._set_authentication('sdk') + for server in ['sdk', 'events', 'auth', 'telemetry']: + assert(httpclient._sessions[server].auth.principal == 'bilal') + assert(httpclient._sessions[server].auth.password == 'split') + assert(isinstance(httpclient._sessions[server].auth, HTTPKerberosAuth)) + + httpclient._sessions['sdk'].close() + httpclient._sessions['events'].close() + httpclient._sessions['sdk'] = requests.Session() + httpclient._sessions['events'] = requests.Session() + assert(httpclient._sessions['sdk'].auth == None) + assert(httpclient._sessions['events'].auth == None) + + httpclient._set_authentication('sdk') + assert(httpclient._sessions['sdk'].auth.principal == 'bilal') + assert(httpclient._sessions['sdk'].auth.password == 'split') + assert(isinstance(httpclient._sessions['sdk'].auth, HTTPKerberosAuth)) + assert(httpclient._sessions['events'].auth == None) + + httpclient2 = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=[None, None]) + for server in ['sdk', 'events', 'auth', 'telemetry']: + assert(httpclient2._sessions[server].auth.principal == None) + assert(httpclient2._sessions[server].auth.password == None) + assert(isinstance(httpclient2._sessions[server].auth, HTTPKerberosAuth)) + + httpclient3 = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=['bilal', 'split']) + for server in ['sdk', 'events', 'auth', 'telemetry']: + assert(httpclient3._sessions[server].adapters['https://']._principal == 'bilal') + assert(httpclient3._sessions[server].adapters['https://']._password == 'split') + assert(isinstance(httpclient3._sessions[server].adapters['https://'], client.HTTPAdapterWithProxyKerberosAuth)) + + httpclient4 = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=[None, None]) + for server in ['sdk', 'events', 'auth', 'telemetry']: + assert(httpclient4._sessions[server].adapters['https://']._principal == None) + assert(httpclient4._sessions[server].adapters['https://']._password == None) + assert(isinstance(httpclient4._sessions[server].adapters['https://'], client.HTTPAdapterWithProxyKerberosAuth)) + + def test_proxy_exception(self, mocker): + global count + count = 0 + class get_mock(object): + def __init__(self, url, params, headers, timeout): + pass + + def __enter__(self): + global count + count += 1 + if count == 1: + raise requests.exceptions.ProxyError() + + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + return response_mock + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + mocker.patch('splitio.api.client.requests.Session.get', new=get_mock) + httpclient = client.HttpClientKerberos(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS_SPNEGO, authentication_params=[None, None]) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert response.status_code == 200 + assert response.body == 'ok' + + count = 0 + class post_mock(object): + def __init__(self, url, params, headers, json, timeout): + pass + + def __enter__(self): + global count + count += 1 + if count == 1: + raise requests.exceptions.ProxyError() + + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + return response_mock + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + mocker.patch('splitio.api.client.requests.Session.post', new=post_mock) + + httpclient = client.HttpClientKerberos(timeout=1500, sdk_url='https://sdk.com', events_url='https://events.com', authentication_scheme=AuthenticateScheme.KERBEROS_PROXY, authentication_params=[None, None]) + httpclient.set_telemetry_data("metric", mocker.Mock()) + response = httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert response.status_code == 200 + assert response.body == 'ok' + + + + def test_telemetry(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.headers = {} + response_mock.text = 'ok' + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.api.client.requests.post', new=get_mock) + httpclient = client.HttpClient(timeout=1500, sdk_url='https://sdk.com', events_url='https://events.com') + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + + self.metric1 = None + self.cur_time = 0 + def record_successful_sync(metric_name, cur_time): + self.metric1 = metric_name + self.cur_time = cur_time + httpclient._telemetry_runtime_producer.record_successful_sync = record_successful_sync + + self.metric2 = None + self.elapsed = 0 + def record_sync_latency(metric_name, elapsed): + self.metric2 = metric_name + self.elapsed = elapsed + httpclient._telemetry_runtime_producer.record_sync_latency = record_sync_latency + + self.metric3 = None + self.status = 0 + def record_sync_error(metric_name, elapsed): + self.metric3 = metric_name + self.status = elapsed + httpclient._telemetry_runtime_producer.record_sync_error = record_sync_error + + httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert (self.metric2 == "metric") + assert (self.metric1 == "metric") + assert (self.cur_time > self.elapsed) + + response_mock.status_code = 400 + response_mock.headers = {} + response_mock.text = 'ok' + httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert (self.metric3 == "metric") + assert (self.status == 400) + + # testing get call + mocker.patch('splitio.api.client.requests.get', new=get_mock) + self.metric1 = None + self.cur_time = 0 + self.metric2 = None + self.elapsed = 0 + response_mock.status_code = 200 + httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert (self.metric2 == "metric") + assert (self.metric1 == "metric") + assert (self.cur_time > self.elapsed) + + self.metric3 = None + self.status = 0 + response_mock.status_code = 400 + httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert (self.metric3 == "metric") + assert (self.status == 400) + +class MockResponse: + def __init__(self, text, status, headers): + self._text = text + self.status = status + self.headers = headers + + async def text(self): + return self._text + + async def __aexit__(self, exc_type, exc, tb): + pass + + async def __aenter__(self): + return self + +class HttpClientAsyncTests(object): + """Http Client test cases.""" + + @pytest.mark.asyncio + async def test_get(self, mocker): + """Test HTTP GET verb requests.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + response_mock = MockResponse('ok', 200, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.get', new=get_mock) + httpclient = client.HttpClientAsync() + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + response = await httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert response.status_code == 200 + assert response.body == 'ok' + call = mocker.call( + client.SDK_URL + '/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + response = await httpclient.get('events', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.EVENTS_URL + '/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert get_mock.mock_calls == [call] + assert response.status_code == 200 + assert response.body == 'ok' + + @pytest.mark.asyncio + async def test_get_custom_urls(self, mocker): + """Test HTTP GET verb requests.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + response_mock = MockResponse('ok', 200, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.get', new=get_mock) + httpclient = client.HttpClientAsync(sdk_url='https://sdk.com', events_url='https://events.com') + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + response = await httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://sdk.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert get_mock.mock_calls == [call] + assert response.status_code == 200 + assert response.body == 'ok' + get_mock.reset_mock() + + response = await httpclient.get('events', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://events.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + + @pytest.mark.asyncio + async def test_post(self, mocker): + """Test HTTP POST verb requests.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + response_mock = MockResponse('ok', 200, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.post', new=get_mock) + httpclient = client.HttpClientAsync() + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + response = await httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.SDK_URL + '/test1', + json={"p1": "a"}, + headers={'Content-Type': 'application/json', 'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Accept-Encoding': 'gzip'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + response = await httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.EVENTS_URL + '/test1', + json={'p1': 'a'}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json', 'Accept-Encoding': 'gzip'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + + @pytest.mark.asyncio + async def test_post_custom_urls(self, mocker): + """Test HTTP GET verb requests.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + response_mock = MockResponse('ok', 200, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.post', new=get_mock) + httpclient = client.HttpClientAsync(sdk_url='https://sdk.com', events_url='https://events.com') + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + response = await httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://sdk.com' + '/test1', + json={"p1": "a"}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json', 'Accept-Encoding': 'gzip'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + response = await httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://events.com' + '/test1', + json={"p1": "a"}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json', 'Accept-Encoding': 'gzip'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + + @pytest.mark.asyncio + async def test_telemetry(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + response_mock = MockResponse('ok', 200, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.post', new=get_mock) + httpclient = client.HttpClientAsync(sdk_url='https://sdk.com', events_url='https://events.com') + httpclient.set_telemetry_data("metric", telemetry_runtime_producer) + + self.metric1 = None + self.cur_time = 0 + async def record_successful_sync(metric_name, cur_time): + self.metric1 = metric_name + self.cur_time = cur_time + httpclient._telemetry_runtime_producer.record_successful_sync = record_successful_sync + + self.metric2 = None + self.elapsed = 0 + async def record_sync_latency(metric_name, elapsed): + self.metric2 = metric_name + self.elapsed = elapsed + httpclient._telemetry_runtime_producer.record_sync_latency = record_sync_latency + + self.metric3 = None + self.status = 0 + async def record_sync_error(metric_name, elapsed): + self.metric3 = metric_name + self.status = elapsed + httpclient._telemetry_runtime_producer.record_sync_error = record_sync_error + + await httpclient.post('events', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert (self.metric2 == "metric") + assert (self.metric1 == "metric") + assert (self.cur_time > self.elapsed) + + response_mock = MockResponse('ok', 400, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.post', new=get_mock) + await httpclient.post('sdk', 'test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + assert (self.metric3 == "metric") + assert (self.status == 400) + + # testing get call + response_mock = MockResponse('ok', 200, {}) + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.optional.loaders.aiohttp.ClientSession.get', new=get_mock) + self.metric1 = None + self.cur_time = 0 + self.metric2 = None + self.elapsed = 0 + await httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert (self.metric2 == "metric") + assert (self.metric1 == "metric") + assert (self.cur_time > self.elapsed) + + self.metric3 = None + self.status = 0 + response_mock = MockResponse('ok', 400, {}) + get_mock.return_value = response_mock + await httpclient.get('sdk', 'test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + assert (self.metric3 == "metric") + assert (self.status == 400) diff --git a/tests/api/test_impressions_api.py b/tests/api/test_impressions_api.py index 54d64b1a..b022a464 100644 --- a/tests/api/test_impressions_api.py +++ b/tests/api/test_impressions_api.py @@ -1,64 +1,72 @@ """Impressions API tests module.""" import pytest +import unittest.mock as mock + from splitio.api import impressions, client, APIException from splitio.models.impressions import Impression -from splitio.engine.impressions import Counter, ImpressionsMode +from splitio.engine.impressions.impressions import ImpressionsMode +from splitio.engine.impressions.manager import Counter from splitio.client.util import get_metadata from splitio.client.config import DEFAULT_CONFIG from splitio.version import __version__ +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync - -class ImpressionsAPITests(object): - """Impressions API test cases.""" - impressions = [ - Impression('k1', 'f1', 'on', 'l1', 123456, 'b1', 321654), - Impression('k2', 'f2', 'off', 'l1', 123456, 'b1', 321654), - Impression('k3', 'f1', 'on', 'l1', 123456, 'b1', 321654) +impressions_mock = [ + Impression('k1', 'f1', 'on', 'l1', 123456, 'b1', 321654, None, {'prop': 'val'}), + Impression('k2', 'f2', 'off', 'l1', 123456, 'b1', 321654, None, None), + Impression('k3', 'f1', 'on', 'l1', 123456, 'b1', 321654, None, None) +] +expectedImpressions = [{ + 'f': 'f1', + 'i': [ + {'k': 'k1', 'b': 'b1', 't': 'on', 'r': 'l1', 'm': 321654, 'c': 123456, 'pt': None, 'properties': {"prop": "val"}}, + {'k': 'k3', 'b': 'b1', 't': 'on', 'r': 'l1', 'm': 321654, 'c': 123456, 'pt': None}, + ], +}, { + 'f': 'f2', + 'i': [ + {'k': 'k2', 'b': 'b1', 't': 'off', 'r': 'l1', 'm': 321654, 'c': 123456, 'pt': None}, ] - expectedImpressions = [{ - 'f': 'f1', - 'i': [ - {'k': 'k1', 'b': 'b1', 't': 'on', 'r': 'l1', 'm': 321654, 'c': 123456, 'pt': None}, - {'k': 'k3', 'b': 'b1', 't': 'on', 'r': 'l1', 'm': 321654, 'c': 123456, 'pt': None}, - ], - }, { - 'f': 'f2', - 'i': [ - {'k': 'k2', 'b': 'b1', 't': 'off', 'r': 'l1', 'm': 321654, 'c': 123456, 'pt': None}, - ] - }] - - counters = [ - Counter.CountPerFeature('f1', 123, 2), - Counter.CountPerFeature('f2', 123, 123), - Counter.CountPerFeature('f1', 456, 111), - Counter.CountPerFeature('f2', 456, 222) +}] + +counters = [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + Counter.CountPerFeature('f1', 456, 111), + Counter.CountPerFeature('f2', 456, 222) +] + +expected_counters = { + 'pf': [ + {'f': 'f1', 'm': 123, 'rc': 2}, + {'f': 'f2', 'm': 123, 'rc': 123}, + {'f': 'f1', 'm': 456, 'rc': 111}, + {'f': 'f2', 'm': 456, 'rc': 222}, ] +} - expected_counters = { - 'pf': [ - {'f': 'f1', 'm': 123, 'rc': 2}, - {'f': 'f2', 'm': 123, 'rc': 123}, - {'f': 'f1', 'm': 456, 'rc': 111}, - {'f': 'f2', 'm': 456, 'rc': 222}, - ] - } +class ImpressionsAPITests(object): + """Impressions API test cases.""" def test_post_impressions(self, mocker): """Test impressions posting API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.post.return_value = client.HttpResponse(200, '') + httpclient.post.return_value = client.HttpResponse(200, '', {}) cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) sdk_metadata = get_metadata(cfg) - impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata) - response = impressions_api.flush_impressions(self.impressions) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + response = impressions_api.flush_impressions(impressions_mock) call_made = httpclient.post.mock_calls[0] # validate positional arguments - assert call_made[1] == ('events', '/testImpressions/bulk', 'some_api_key') + assert call_made[1] == ('events', 'testImpressions/bulk', 'some_api_key') # validate key-value args (headers) assert call_made[2]['extra_headers'] == { @@ -69,31 +77,31 @@ def test_post_impressions(self, mocker): } # validate key-value args (body) - assert call_made[2]['body'] == self.expectedImpressions + assert call_made[2]['body'] == expectedImpressions httpclient.reset_mock() def raise_exception(*args, **kwargs): raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with pytest.raises(APIException) as exc_info: - response = impressions_api.flush_impressions(self.impressions) + response = impressions_api.flush_impressions(impressions_mock) assert exc_info.type == APIException assert exc_info.value.message == 'some_message' def test_post_impressions_ip_address_disabled(self, mocker): """Test impressions posting API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.post.return_value = client.HttpResponse(200, '') + httpclient.post.return_value = client.HttpResponse(200, '', {}) cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': False}) sdk_metadata = get_metadata(cfg) - impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata, ImpressionsMode.DEBUG) - response = impressions_api.flush_impressions(self.impressions) + impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata, mocker.Mock(), ImpressionsMode.DEBUG) + response = impressions_api.flush_impressions(impressions_mock) call_made = httpclient.post.mock_calls[0] # validate positional arguments - assert call_made[1] == ('events', '/testImpressions/bulk', 'some_api_key') + assert call_made[1] == ('events', 'testImpressions/bulk', 'some_api_key') # validate key-value args (headers) assert call_made[2]['extra_headers'] == { @@ -102,22 +110,22 @@ def test_post_impressions_ip_address_disabled(self, mocker): } # validate key-value args (body) - assert call_made[2]['body'] == self.expectedImpressions + assert call_made[2]['body'] == expectedImpressions def test_post_counters(self, mocker): """Test impressions posting API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.post.return_value = client.HttpResponse(200, '') + httpclient.post.return_value = client.HttpResponse(200, '', {}) cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) sdk_metadata = get_metadata(cfg) - impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata) - response = impressions_api.flush_counters(self.counters) + impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata, mocker.Mock()) + response = impressions_api.flush_counters(counters) call_made = httpclient.post.mock_calls[0] # validate positional arguments - assert call_made[1] == ('events', '/testImpressions/count', 'some_api_key') + assert call_made[1] == ('events', 'testImpressions/count', 'some_api_key') # validate key-value args (headers) assert call_made[2]['extra_headers'] == { @@ -128,13 +136,159 @@ def test_post_counters(self, mocker): } # validate key-value args (body) - assert call_made[2]['body'] == self.expected_counters + assert call_made[2]['body'] == expected_counters httpclient.reset_mock() def raise_exception(*args, **kwargs): raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with pytest.raises(APIException) as exc_info: - response = impressions_api.flush_counters(self.counters) + response = impressions_api.flush_counters(counters) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + +class ImpressionsAPIAsyncTests(object): + """Impressions API test cases.""" + + @pytest.mark.asyncio + async def test_post_impressions(self, mocker): + """Test impressions posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impressions_api = impressions.ImpressionsAPIAsync(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await impressions_api.flush_impressions(impressions_mock) + + # validate positional arguments + assert self.verb == 'events' + assert self.url == 'testImpressions/bulk' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name', + 'SplitSDKImpressionsMode': 'OPTIMIZED' + } + + # validate key-value args (body) + assert self.body == expectedImpressions + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post = raise_exception + with pytest.raises(APIException) as exc_info: + response = await impressions_api.flush_impressions(impressions_mock) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + @pytest.mark.asyncio + async def test_post_impressions_ip_address_disabled(self, mocker): + """Test impressions posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + sdk_metadata = get_metadata(cfg) + impressions_api = impressions.ImpressionsAPIAsync(httpclient, 'some_api_key', sdk_metadata, mocker.Mock(), ImpressionsMode.DEBUG) + + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await impressions_api.flush_impressions(impressions_mock) + + # validate positional arguments + assert self.verb == 'events' + assert self.url == 'testImpressions/bulk' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKImpressionsMode': 'DEBUG' + } + + # validate key-value args (body) + assert self.body == expectedImpressions + + @pytest.mark.asyncio + async def test_post_counters(self, mocker): + """Test impressions posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + impressions_api = impressions.ImpressionsAPIAsync(httpclient, 'some_api_key', sdk_metadata, mocker.Mock()) + + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await impressions_api.flush_counters(counters) + + # validate positional arguments + assert self.verb == 'events' + assert self.url == 'testImpressions/count' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name', + 'SplitSDKImpressionsMode': 'OPTIMIZED' + } + + # validate key-value args (body) + assert self.body == expected_counters + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post = raise_exception + with pytest.raises(APIException) as exc_info: + response = await impressions_api.flush_counters(counters) assert exc_info.type == APIException assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_segments_api.py b/tests/api/test_segments_api.py index 1998469a..8681be59 100644 --- a/tests/api/test_segments_api.py +++ b/tests/api/test_segments_api.py @@ -1,35 +1,35 @@ """Segment API tests module.""" import pytest +import unittest.mock as mock from splitio.api import segments, client, APIException from splitio.api.commons import FetchOptions from splitio.client.util import SdkMetadata - class SegmentAPITests(object): """Segment API test cases.""" def test_fetch_segment_changes(self, mocker): """Test segment changes fetching API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.get.return_value = client.HttpResponse(200, '{"prop1": "value1"}') - segment_api = segments.SegmentsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4')) + httpclient.get.return_value = client.HttpResponse(200, '{"prop1": "value1"}', {}) + segment_api = segments.SegmentsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) - response = segment_api.fetch_segment('some_segment', 123, FetchOptions()) + response = segment_api.fetch_segment('some_segment', 123, FetchOptions(None, None, None, None, None)) assert response['prop1'] == 'value1' - assert httpclient.get.mock_calls == [mocker.call('sdk', '/segmentChanges/some_segment', 'some_api_key', + assert httpclient.get.mock_calls == [mocker.call('sdk', 'segmentChanges/some_segment', 'some_api_key', extra_headers={ - 'SplitSDKVersion': '1.0', - 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', 'SplitSDKMachineName': 'some' }, query={'since': 123})] httpclient.reset_mock() - response = segment_api.fetch_segment('some_segment', 123, FetchOptions(True)) + response = segment_api.fetch_segment('some_segment', 123, FetchOptions(True, None, None, None, None)) assert response['prop1'] == 'value1' - assert httpclient.get.mock_calls == [mocker.call('sdk', '/segmentChanges/some_segment', 'some_api_key', + assert httpclient.get.mock_calls == [mocker.call('sdk', 'segmentChanges/some_segment', 'some_api_key', extra_headers={ 'SplitSDKVersion': '1.0', 'SplitSDKMachineIP': '1.2.3.4', @@ -39,9 +39,9 @@ def test_fetch_segment_changes(self, mocker): query={'since': 123})] httpclient.reset_mock() - response = segment_api.fetch_segment('some_segment', 123, FetchOptions(True, 123)) + response = segment_api.fetch_segment('some_segment', 123, FetchOptions(True, 123, None, None, None)) assert response['prop1'] == 'value1' - assert httpclient.get.mock_calls == [mocker.call('sdk', '/segmentChanges/some_segment', 'some_api_key', + assert httpclient.get.mock_calls == [mocker.call('sdk', 'segmentChanges/some_segment', 'some_api_key', extra_headers={ 'SplitSDKVersion': '1.0', 'SplitSDKMachineIP': '1.2.3.4', @@ -58,3 +58,76 @@ def raise_exception(*args, **kwargs): response = segment_api.fetch_segment('some_segment', 123, FetchOptions()) assert exc_info.type == APIException assert exc_info.value.message == 'some_message' + + +class SegmentAPIAsyncTests(object): + """Segment async API test cases.""" + + @pytest.mark.asyncio + async def test_fetch_segment_changes(self, mocker): + """Test segment changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + segment_api = segments.SegmentsAPIAsync(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + + self.verb = None + self.url = None + self.key = None + self.headers = None + self.query = None + async def get(verb, url, key, query, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.query = query + return client.HttpResponse(200, '{"prop1": "value1"}', {}) + httpclient.get = get + + response = await segment_api.fetch_segment('some_segment', 123, FetchOptions(None, None, None, None, None)) + assert response['prop1'] == 'value1' + assert self.verb == 'sdk' + assert self.url == 'segmentChanges/some_segment' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKMachineName': 'some' + } + assert self.query == {'since': 123} + + httpclient.reset_mock() + response = await segment_api.fetch_segment('some_segment', 123, FetchOptions(True, None, None, None, None)) + assert response['prop1'] == 'value1' + assert self.verb == 'sdk' + assert self.url == 'segmentChanges/some_segment' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKMachineName': 'some', + 'Cache-Control': 'no-cache' + } + assert self.query == {'since': 123} + + httpclient.reset_mock() + response = await segment_api.fetch_segment('some_segment', 123, FetchOptions(True, 123, None, None, None)) + assert response['prop1'] == 'value1' + assert self.verb == 'sdk' + assert self.url == 'segmentChanges/some_segment' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKMachineName': 'some', + 'Cache-Control': 'no-cache' + } + assert self.query == {'since': 123, 'till': 123} + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.get = raise_exception + with pytest.raises(APIException) as exc_info: + response = await segment_api.fetch_segment('some_segment', 123, FetchOptions(None, None, None, None, None)) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_splits_api.py b/tests/api/test_splits_api.py index 5e914712..c9aeee8b 100644 --- a/tests/api/test_splits_api.py +++ b/tests/api/test_splits_api.py @@ -1,60 +1,332 @@ """Split API tests module.""" import pytest +import unittest.mock as mock +import time from splitio.api import splits, client, APIException from splitio.api.commons import FetchOptions from splitio.client.util import SdkMetadata - class SplitAPITests(object): """Split API test cases.""" def test_fetch_split_changes(self, mocker): """Test split changes fetching API call.""" httpclient = mocker.Mock(spec=client.HttpClient) - httpclient.get.return_value = client.HttpResponse(200, '{"prop1": "value1"}') - split_api = splits.SplitsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4')) + httpclient.get.return_value = client.HttpResponse(200, '{"prop1": "value1"}', {}) + split_api = splits.SplitsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) - response = split_api.fetch_splits(123, FetchOptions()) + response = split_api.fetch_splits(123, -1, FetchOptions(False, None, None, 'set1,set2')) assert response['prop1'] == 'value1' - assert httpclient.get.mock_calls == [mocker.call('sdk', '/splitChanges', 'some_api_key', + assert httpclient.get.mock_calls == [mocker.call('sdk', 'splitChanges', 'some_api_key', extra_headers={ 'SplitSDKVersion': '1.0', 'SplitSDKMachineIP': '1.2.3.4', 'SplitSDKMachineName': 'some' }, - query={'since': 123})] + query={'s': '1.3', 'since': 123, 'rbSince': -1, 'sets': 'set1,set2'})] httpclient.reset_mock() - response = split_api.fetch_splits(123, FetchOptions(True)) + response = split_api.fetch_splits(123, 1, FetchOptions(True, 123, None,'set3')) assert response['prop1'] == 'value1' - assert httpclient.get.mock_calls == [mocker.call('sdk', '/splitChanges', 'some_api_key', + assert httpclient.get.mock_calls == [mocker.call('sdk', 'splitChanges', 'some_api_key', extra_headers={ 'SplitSDKVersion': '1.0', 'SplitSDKMachineIP': '1.2.3.4', 'SplitSDKMachineName': 'some', 'Cache-Control': 'no-cache' }, - query={'since': 123})] + query={'s': '1.3', 'since': 123, 'rbSince': 1, 'till': 123, 'sets': 'set3'})] httpclient.reset_mock() - response = split_api.fetch_splits(123, FetchOptions(True, 123)) + response = split_api.fetch_splits(123, 122, FetchOptions(True, 123, None, 'set3')) assert response['prop1'] == 'value1' - assert httpclient.get.mock_calls == [mocker.call('sdk', '/splitChanges', 'some_api_key', + assert httpclient.get.mock_calls == [mocker.call('sdk', 'splitChanges', 'some_api_key', extra_headers={ 'SplitSDKVersion': '1.0', 'SplitSDKMachineIP': '1.2.3.4', 'SplitSDKMachineName': 'some', 'Cache-Control': 'no-cache' }, - query={'since': 123, 'till': 123})] + query={'s': '1.3', 'since': 123, 'rbSince': 122, 'till': 123, 'sets': 'set3'})] httpclient.reset_mock() def raise_exception(*args, **kwargs): raise client.HttpClientException('some_message') httpclient.get.side_effect = raise_exception with pytest.raises(APIException) as exc_info: - response = split_api.fetch_splits(123, FetchOptions()) + response = split_api.fetch_splits(123, 12, FetchOptions()) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + def test_old_spec(self, mocker): + """Test old split changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + self.counter = 0 + self.query = [] + def get(sdk, splitChanges, sdk_key, extra_headers, query): + self.counter += 1 + self.query.append(query) + if self.counter == 1: + return client.HttpResponse(400, 'error', {}) + if self.counter == 2: + return client.HttpResponse(200, '{"splits": [], "since": 123, "till": 456}', {}) + + httpclient.get = get + split_api = splits.SplitsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + + httpclient.is_sdk_endpoint_overridden.return_value = False + try: + response = split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + except Exception as e: + print(e) + + # no attempt to fetch old spec + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}] + + httpclient.is_sdk_endpoint_overridden.return_value = True + self.query = [] + self.counter = 0 + response = split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}, {'s': '1.1', 'since': 123}] + assert not split_api.clear_storage + + def test_switch_to_new_spec(self, mocker): + """Test old split changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + self.counter = 0 + self.query = [] + def get(sdk, splitChanges, sdk_key, extra_headers, query): + self.counter += 1 + self.query.append(query) + if self.counter == 1: + return client.HttpResponse(400, 'error', {}) + if self.counter == 2: + return client.HttpResponse(200, '{"splits": [], "since": 123, "till": 456}', {}) + if self.counter == 3: + return client.HttpResponse(200, '{"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": 123, "t": -1}}', {}) + + httpclient.is_sdk_endpoint_overridden.return_value = True + httpclient.get = get + split_api = splits.SplitsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + response = split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}, {'s': '1.1', 'since': 123}] + assert not split_api.clear_storage + + time.sleep(1) + splits._PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 10 + response = split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert self.query[2] == {'s': '1.3', 'since': 123, 'rbSince': -1} + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": 123, "t": -1}} + assert split_api.clear_storage + + def test_using_old_spec_since(self, mocker): + """Test using old_spec_since variable.""" + httpclient = mocker.Mock(spec=client.HttpClient) + self.counter = 0 + self.query = [] + def get(sdk, splitChanges, sdk_key, extra_headers, query): + self.counter += 1 + self.query.append(query) + if self.counter == 1: + return client.HttpResponse(400, 'error', {}) + if self.counter == 2: + return client.HttpResponse(200, '{"splits": [], "since": 123, "till": 456}', {}) + if self.counter == 3: + return client.HttpResponse(400, 'error', {}) + if self.counter == 4: + return client.HttpResponse(200, '{"splits": [], "since": 456, "till": 456}', {}) + + httpclient.is_sdk_endpoint_overridden.return_value = True + httpclient.get = get + split_api = splits.SplitsAPI(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + response = split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}, {'s': '1.1', 'since': 123}] + assert not split_api.clear_storage + + time.sleep(1) + splits._PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 10 + + response = split_api.fetch_splits(456, -1, FetchOptions(False, None, None, None)) + time.sleep(1) + splits._PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 1000000 + assert self.query[2] == {'s': '1.3', 'since': 456, 'rbSince': -1} + assert self.query[3] == {'s': '1.1', 'since': 456} + assert response == {"ff": {"d": [], "s": 456, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + +class SplitAPIAsyncTests(object): + """Split async API test cases.""" + + @pytest.mark.asyncio + async def test_fetch_split_changes(self, mocker): + """Test split changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + split_api = splits.SplitsAPIAsync(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + self.verb = None + self.url = None + self.key = None + self.headers = None + self.query = None + async def get(verb, url, key, query, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.query = query + return client.HttpResponse(200, '{"prop1": "value1"}', {}) + httpclient.get = get + + response = await split_api.fetch_splits(123, -1, FetchOptions(False, None, None, 'set1,set2')) + assert response['prop1'] == 'value1' + assert self.verb == 'sdk' + assert self.url == 'splitChanges' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKMachineName': 'some' + } + assert self.query == {'s': '1.3', 'since': 123, 'rbSince': -1, 'sets': 'set1,set2'} + + httpclient.reset_mock() + response = await split_api.fetch_splits(123, 1, FetchOptions(True, 123, None, 'set3')) + assert response['prop1'] == 'value1' + assert self.verb == 'sdk' + assert self.url == 'splitChanges' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKMachineName': 'some', + 'Cache-Control': 'no-cache' + } + assert self.query == {'s': '1.3', 'since': 123, 'rbSince': 1, 'till': 123, 'sets': 'set3'} + + httpclient.reset_mock() + response = await split_api.fetch_splits(123, 122, FetchOptions(True, 123, None)) + assert response['prop1'] == 'value1' + assert self.verb == 'sdk' + assert self.url == 'splitChanges' + assert self.key == 'some_api_key' + assert self.headers == { + 'SplitSDKVersion': '1.0', + 'SplitSDKMachineIP': '1.2.3.4', + 'SplitSDKMachineName': 'some', + 'Cache-Control': 'no-cache' + } + assert self.query == {'s': '1.3', 'since': 123, 'rbSince': 122, 'till': 123} + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.get = raise_exception + with pytest.raises(APIException) as exc_info: + response = await split_api.fetch_splits(123, 12, FetchOptions()) assert exc_info.type == APIException assert exc_info.value.message == 'some_message' + + @pytest.mark.asyncio + async def test_old_spec(self, mocker): + """Test old split changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + self.counter = 0 + self.query = [] + async def get(sdk, splitChanges, sdk_key, extra_headers, query): + self.counter += 1 + self.query.append(query) + if self.counter == 1: + return client.HttpResponse(400, 'error', {}) + if self.counter == 2: + return client.HttpResponse(200, '{"splits": [], "since": 123, "till": 456}', {}) + + httpclient.is_sdk_endpoint_overridden.return_value = True + httpclient.get = get + split_api = splits.SplitsAPIAsync(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + + httpclient.is_sdk_endpoint_overridden.return_value = False + try: + response = await split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + except Exception as e: + print(e) + + # no attempt to fetch old spec + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}] + + httpclient.is_sdk_endpoint_overridden.return_value = True + self.query = [] + self.counter = 0 + response = await split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}, {'s': '1.1', 'since': 123}] + assert not split_api.clear_storage + + @pytest.mark.asyncio + async def test_switch_to_new_spec(self, mocker): + """Test old split changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + self.counter = 0 + self.query = [] + async def get(sdk, splitChanges, sdk_key, extra_headers, query): + self.counter += 1 + self.query.append(query) + if self.counter == 1: + return client.HttpResponse(400, 'error', {}) + if self.counter == 2: + return client.HttpResponse(200, '{"splits": [], "since": 123, "till": 456}', {}) + if self.counter == 3: + return client.HttpResponse(200, '{"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": 123, "t": -1}}', {}) + + httpclient.is_sdk_endpoint_overridden.return_value = True + httpclient.get = get + split_api = splits.SplitsAPIAsync(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + response = await split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}, {'s': '1.1', 'since': 123}] + assert not split_api.clear_storage + + time.sleep(1) + splits._PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 10 + response = await split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert self.query[2] == {'s': '1.3', 'since': 123, 'rbSince': -1} + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": 123, "t": -1}} + assert split_api.clear_storage + + @pytest.mark.asyncio + async def test_using_old_spec_since(self, mocker): + """Test using old_spec_since variable.""" + httpclient = mocker.Mock(spec=client.HttpClient) + self.counter = 0 + self.query = [] + async def get(sdk, splitChanges, sdk_key, extra_headers, query): + self.counter += 1 + self.query.append(query) + if self.counter == 1: + return client.HttpResponse(400, 'error', {}) + if self.counter == 2: + return client.HttpResponse(200, '{"splits": [], "since": 123, "till": 456}', {}) + if self.counter == 3: + return client.HttpResponse(400, 'error', {}) + if self.counter == 4: + return client.HttpResponse(200, '{"splits": [], "since": 456, "till": 456}', {}) + + httpclient.is_sdk_endpoint_overridden.return_value = True + httpclient.get = get + split_api = splits.SplitsAPIAsync(httpclient, 'some_api_key', SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + response = await split_api.fetch_splits(123, -1, FetchOptions(False, None, None, None)) + assert response == {"ff": {"d": [], "s": 123, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} + assert self.query == [{'s': '1.3', 'since': 123, 'rbSince': -1}, {'s': '1.1', 'since': 123}] + assert not split_api.clear_storage + + time.sleep(1) + splits._PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 10 + + response = await split_api.fetch_splits(456, -1, FetchOptions(False, None, None, None)) + time.sleep(1) + splits._PROXY_CHECK_INTERVAL_MILLISECONDS_SS = 1000000 + assert self.query[2] == {'s': '1.3', 'since': 456, 'rbSince': -1} + assert self.query[3] == {'s': '1.1', 'since': 456} + assert response == {"ff": {"d": [], "s": 456, "t": 456}, "rbs": {"d": [], "s": -1, "t": -1}} diff --git a/tests/api/test_telemetry_api.py b/tests/api/test_telemetry_api.py new file mode 100644 index 00000000..5a857789 --- /dev/null +++ b/tests/api/test_telemetry_api.py @@ -0,0 +1,266 @@ +"""Impressions API tests module.""" + +import pytest +import unittest.mock as mock + +from splitio.api import telemetry, client, APIException +#from splitio.models.telemetry import +from splitio.client.util import get_metadata +from splitio.client.config import DEFAULT_CONFIG +from splitio.version import __version__ +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync + + +class TelemetryAPITests(object): + """Telemetry API test cases.""" + + def test_record_unique_keys(self, mocker): + """Test telemetry posting unique keys.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '', {}) + uniques = {'keys': [1, 2, 3]} + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_api = telemetry.TelemetryAPI(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + response = telemetry_api.record_unique_keys(uniques) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('telemetry', 'v1/keys/ss', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == uniques + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = telemetry_api.record_unique_keys(uniques) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + def test_record_init(self, mocker): + """Test telemetry posting init configs.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '', {}) + uniques = {'keys': [1, 2, 3]} + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_api = telemetry.TelemetryAPI(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + response = telemetry_api.record_init(uniques) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('telemetry', 'v1/metrics/config', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == uniques + + def test_record_stats(self, mocker): + """Test telemetry posting stats.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '', {}) + uniques = {'keys': [1, 2, 3]} + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_api = telemetry.TelemetryAPI(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + response = telemetry_api.record_stats(uniques) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('telemetry', 'v1/metrics/usage', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == uniques + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = telemetry_api.record_stats(uniques) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + +class TelemetryAPIAsyncTests(object): + """Telemetry API test cases.""" + + @pytest.mark.asyncio + async def test_record_unique_keys(self, mocker): + """Test telemetry posting unique keys.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + uniques = {'keys': [1, 2, 3]} + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_api = telemetry.TelemetryAPIAsync(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await telemetry_api.record_unique_keys(uniques) + assert self.verb == 'telemetry' + assert self.url == 'v1/keys/ss' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert self.body == uniques + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post = raise_exception + with pytest.raises(APIException) as exc_info: + response = await telemetry_api.record_unique_keys(uniques) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + @pytest.mark.asyncio + async def test_record_init(self, mocker): + """Test telemetry posting unique keys.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + uniques = {'keys': [1, 2, 3]} + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_api = telemetry.TelemetryAPIAsync(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await telemetry_api.record_init(uniques) + assert self.verb == 'telemetry' + assert self.url == 'v1/metrics/config' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert self.body == uniques + + @pytest.mark.asyncio + async def test_record_stats(self, mocker): + """Test telemetry posting unique keys.""" + httpclient = mocker.Mock(spec=client.HttpClientAsync) + uniques = {'keys': [1, 2, 3]} + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': True, 'machineName': 'some_machine_name', 'machineIp': '123.123.123.123'}) + sdk_metadata = get_metadata(cfg) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_api = telemetry.TelemetryAPIAsync(httpclient, 'some_api_key', sdk_metadata, telemetry_runtime_producer) + self.verb = None + self.url = None + self.key = None + self.headers = None + self.body = None + async def post(verb, url, key, body, extra_headers): + self.url = url + self.verb = verb + self.key = key + self.headers = extra_headers + self.body = body + return client.HttpResponse(200, '', {}) + httpclient.post = post + + response = await telemetry_api.record_stats(uniques) + assert self.verb == 'telemetry' + assert self.url == 'v1/metrics/usage' + assert self.key == 'some_api_key' + + # validate key-value args (headers) + assert self.headers == { + 'SplitSDKVersion': 'python-%s' % __version__, + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert self.body == uniques + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message') + httpclient.post = raise_exception + with pytest.raises(APIException) as exc_info: + response = await telemetry_api.record_stats(uniques) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_util.py b/tests/api/test_util.py index c245c157..51876f52 100644 --- a/tests/api/test_util.py +++ b/tests/api/test_util.py @@ -1,9 +1,13 @@ """Split API tests module.""" import pytest +import unittest.mock as mock -from splitio.api.commons import headers_from_metadata +from splitio.api import headers_from_metadata from splitio.client.util import SdkMetadata +from splitio.engine.telemetry import TelemetryStorageProducer +from splitio.storage.inmemmory import InMemoryTelemetryStorage +from splitio.models.telemetry import HTTPExceptionsAndLatencies class UtilTests(object): @@ -34,5 +38,3 @@ def test_headers_from_metadata(self, mocker): assert 'SplitSDKMachineIP' not in metadata assert 'SplitSDKMachineName' not in metadata assert 'SplitSDKClientKey' not in metadata - - diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 057a9ddc..1efd4143 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,21 +1,32 @@ """SDK main client test module.""" # pylint: disable=no-self-use,protected-access -import json -import os -from splitio.client.client import Client, _LOGGER as _logger, CONTROL -from splitio.client.factory import SplitFactory -from splitio.engine.evaluator import Evaluator +import unittest.mock as mock +import pytest +import queue +import asyncio + +from splitio.client.client import Client, _LOGGER as _logger, CONTROL, ClientAsync, EvaluationOptions +from splitio.client.factory import SplitFactory, Status as FactoryStatus, SplitFactoryAsync +from splitio.events.events_manager import EventsManager, EventsManagerAsync +from splitio.models.fallback_config import FallbackTreatmentsConfiguration, FallbackTreatmentCalculator +from splitio.models.fallback_treatment import FallbackTreatment from splitio.models.impressions import Impression, Label -from splitio.models.events import Event, EventWrapper -from splitio.storage import EventStorage, ImpressionStorage, SegmentStorage, SplitStorage +from splitio.models.events import Event, EventWrapper, SdkEvent +from splitio.storage import EventStorage, ImpressionStorage, SegmentStorage, SplitStorage, RuleBasedSegmentsStorage from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \ - InMemoryImpressionStorage, InMemoryEventStorage -from splitio.models import splits, segments -from splitio.engine.impressions import Manager as ImpressionManager - -# Recorder -from splitio.recorder.recorder import StandardRecorder + InMemoryImpressionStorage, InMemoryTelemetryStorage, InMemorySplitStorageAsync, \ + InMemoryImpressionStorageAsync, InMemorySegmentStorageAsync, InMemoryTelemetryStorageAsync, InMemoryEventStorageAsync, \ + InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync +from splitio.models.splits import Split, Status, from_raw +from splitio.engine.impressions.impressions import Manager as ImpressionManager +from splitio.engine.impressions.manager import Counter as ImpressionsCounter +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.engine.telemetry import TelemetryStorageConsumer, TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.engine.evaluator import Evaluator, EvaluationContext +from splitio.recorder.recorder import StandardRecorder, StandardRecorderAsync +from splitio.engine.impressions.strategies import StrategyDebugMode, StrategyNoneMode, StrategyOptimizedMode +from tests.integration import splits_json class ClientTests(object): # pylint: disable=too-few-public-methods @@ -23,174 +34,290 @@ class ClientTests(object): # pylint: disable=too-few-public-methods def test_get_treatment(self, mocker): """Test get_treatment execution paths.""" - split_storage = mocker.Mock(spec=SplitStorage) - segment_storage = mocker.Mock(spec=SegmentStorage) - impression_storage = mocker.Mock(spec=ImpressionStorage) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) event_storage = mocker.Mock(spec=EventStorage) - def _get_storage_mock(name): - return { - 'splits': split_storage, - 'segments': segment_storage, - 'impressions': impression_storage, - 'events': event_storage, - }[name] - destroyed_property = mocker.PropertyMock() destroyed_property.return_value = False - factory = mocker.Mock(spec=SplitFactory) - factory._get_storage.side_effect = _get_storage_mock - factory._waiting_fork.return_value = False - type(factory).destroyed = destroyed_property - mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) - impmanager = mocker.Mock(spec=ImpressionManager) - recorder = StandardRecorder(impmanager, event_storage, impression_storage) - client = Client(factory, recorder, True) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer(), + unique_keys_tracker=UniqueKeysTracker(), + imp_counter=ImpressionsCounter()) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + factory.block_until_ready(5) + + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) client._evaluator = mocker.Mock(spec=Evaluator) - client._evaluator.evaluate_feature.return_value = { + client._evaluator.eval_with_context.return_value = { 'treatment': 'on', 'configurations': None, 'impression': { 'label': 'some_label', 'change_number': 123 }, + 'impressions_disabled': False } _logger = mocker.Mock() - - assert client.get_treatment('some_key', 'some_feature') == 'on' - assert mocker.call( - [(Impression('some_key', 'some_feature', 'on', 'some_label', 123, None, 1000), None)] - ) in impmanager.process_impressions.mock_calls + assert client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None)] assert _logger.mock_calls == [] # Test with client not ready ready_property = mocker.PropertyMock() ready_property.return_value = False type(factory).ready = ready_property - impmanager.process_impressions.reset_mock() - assert client.get_treatment('some_key', 'some_feature', {'some_attribute': 1}) == 'control' - assert mocker.call( - [(Impression('some_key', 'some_feature', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY), {'some_attribute': 1})] - ) in impmanager.process_impressions.mock_calls + # pytest.set_trace() + assert client.get_treatment('some_key', 'SPLIT_2', {'some_attribute': 1}) == 'control' + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', Label.NOT_READY, None, None, 1000, None, None)] # Test with exception: ready_property.return_value = True - split_storage.get_change_number.return_value = -1 - def _raise(*_): - raise Exception('something') - client._evaluator.evaluate_feature.side_effect = _raise - assert client.get_treatment('some_key', 'some_feature') == 'control' - assert mocker.call( - [(Impression('some_key', 'some_feature', 'control', 'exception', -1, None, 1000), None)] - ) in impmanager.process_impressions.mock_calls + raise RuntimeError('something') + client._evaluator.eval_with_context.side_effect = _raise + assert client.get_treatment('some_key', 'SPLIT_2') == 'control' + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', 'exception', None, None, 1000, None, None)] + factory.destroy() def test_get_treatment_with_config(self, mocker): """Test get_treatment execution paths.""" - split_storage = mocker.Mock(spec=SplitStorage) - segment_storage = mocker.Mock(spec=SegmentStorage) - impression_storage = mocker.Mock(spec=ImpressionStorage) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) event_storage = mocker.Mock(spec=EventStorage) - def _get_storage_mock(name): - return { - 'splits': split_storage, - 'segments': segment_storage, - 'impressions': impression_storage, - 'events': event_storage, - }[name] - destroyed_property = mocker.PropertyMock() destroyed_property.return_value = False - factory = mocker.Mock(spec=SplitFactory) - factory._get_storage.side_effect = _get_storage_mock - factory._waiting_fork.return_value = False - type(factory).destroyed = destroyed_property + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'impressions': impression_storage, + 'rule_based_segments': rb_segment_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) - impmanager = mocker.Mock(spec=ImpressionManager) - recorder = StandardRecorder(impmanager, event_storage, impression_storage) - client = Client(factory, recorder, True) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) client._evaluator = mocker.Mock(spec=Evaluator) - client._evaluator.evaluate_feature.return_value = { + client._evaluator.eval_with_context.return_value = { 'treatment': 'on', 'configurations': '{"some_config": True}', 'impression': { 'label': 'some_label', 'change_number': 123 - } + }, + 'impressions_disabled': False } _logger = mocker.Mock() client._send_impression_to_listener = mocker.Mock() assert client.get_treatment_with_config( 'some_key', - 'some_feature' + 'SPLIT_2' ) == ('on', '{"some_config": True}') - assert mocker.call( - [(Impression('some_key', 'some_feature', 'on', 'some_label', 123, None, 1000), None)] - ) in impmanager.process_impressions.mock_calls + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None)] assert _logger.mock_calls == [] # Test with client not ready ready_property = mocker.PropertyMock() ready_property.return_value = False type(factory).ready = ready_property - impmanager.process_impressions.reset_mock() - assert client.get_treatment_with_config('some_key', 'some_feature', {'some_attribute': 1}) == ('control', None) - assert mocker.call( - [(Impression('some_key', 'some_feature', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY), - {'some_attribute': 1})] - ) in impmanager.process_impressions.mock_calls + assert client.get_treatment_with_config('some_key', 'SPLIT_2', {'some_attribute': 1}) == ('control', None) + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] # Test with exception: ready_property.return_value = True - split_storage.get_change_number.return_value = -1 def _raise(*_): - raise Exception('something') - client._evaluator.evaluate_feature.side_effect = _raise - assert client.get_treatment_with_config('some_key', 'some_feature') == ('control', None) - assert mocker.call( - [(Impression('some_key', 'some_feature', 'control', 'exception', -1, None, 1000), None)] - ) in impmanager.process_impressions.mock_calls + raise RuntimeError('something') + client._evaluator.eval_with_context.side_effect = _raise + assert client.get_treatment_with_config('some_key', 'SPLIT_2') == ('control', None) + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', 'exception', None, None, 1000, None, None)] + factory.destroy() def test_get_treatments(self, mocker): """Test get_treatment execution paths.""" - split_storage = mocker.Mock(spec=SplitStorage) - segment_storage = mocker.Mock(spec=SegmentStorage) - impression_storage = mocker.Mock(spec=ImpressionStorage) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) event_storage = mocker.Mock(spec=EventStorage) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation, + 'SPLIT_1': evaluation + } + _logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + treatments = client.get_treatments('key', ['SPLIT_2', 'SPLIT_1']) + assert treatments == {'SPLIT_2': 'on', 'SPLIT_1': 'on'} + + impressions_called = impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert client.get_treatments('some_key', ['SPLIT_2'], {'some_attribute': 1}) == {'SPLIT_2': 'control'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert client.get_treatments('key', ['SPLIT_2', 'SPLIT_1']) == {'SPLIT_2': 'control', 'SPLIT_1': 'control'} + factory.destroy() - def _get_storage_mock(name): - return { - 'splits': split_storage, - 'segments': segment_storage, - 'impressions': impression_storage, - 'events': event_storage, - }[name] + def test_get_treatments_by_flag_set(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) destroyed_property = mocker.PropertyMock() destroyed_property.return_value = False - factory = mocker.Mock(spec=SplitFactory) - factory._get_storage.side_effect = _get_storage_mock - factory._waiting_fork.return_value = False - type(factory).destroyed = destroyed_property + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) - impmanager = mocker.Mock(spec=ImpressionManager) - recorder = StandardRecorder(impmanager, event_storage, impression_storage) - client = Client(factory, recorder, True) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) client._evaluator = mocker.Mock(spec=Evaluator) evaluation = { 'treatment': 'on', @@ -198,69 +325,249 @@ def _get_storage_mock(name): 'impression': { 'label': 'some_label', 'change_number': 123 - } + }, + 'impressions_disabled': False } - client._evaluator.evaluate_features.return_value = { - 'f1': evaluation, - 'f2': evaluation + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation, + 'SPLIT_1': evaluation } _logger = mocker.Mock() client._send_impression_to_listener = mocker.Mock() - assert client.get_treatments('key', ['f1', 'f2']) == {'f1': 'on', 'f2': 'on'} + assert client.get_treatments_by_flag_set('key', 'set_1') == {'SPLIT_2': 'on', 'SPLIT_1': 'on'} - impressions_called = impmanager.process_impressions.mock_calls[0][1][0] - assert (Impression('key', 'f1', 'on', 'some_label', 123, None, 1000), None) in impressions_called - assert (Impression('key', 'f2', 'on', 'some_label', 123, None, 1000), None) in impressions_called + impressions_called = impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called assert _logger.mock_calls == [] # Test with client not ready ready_property = mocker.PropertyMock() ready_property.return_value = False type(factory).ready = ready_property - impmanager.process_impressions.reset_mock() - assert client.get_treatments('some_key', ['some_feature'], {'some_attribute': 1}) == {'some_feature': 'control'} - assert mocker.call( - [(Impression('some_key', 'some_feature', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY), {'some_attribute': 1})] - ) in impmanager.process_impressions.mock_calls + assert client.get_treatments_by_flag_set('some_key', 'set_2', {'some_attribute': 1}) == {'SPLIT_1': 'control'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] # Test with exception: ready_property.return_value = True - split_storage.get_change_number.return_value = -1 def _raise(*_): - raise Exception('something') - client._evaluator.evaluate_features.side_effect = _raise - assert client.get_treatments('key', ['f1', 'f2']) == {'f1': 'control', 'f2': 'control'} + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert client.get_treatments_by_flag_set('key', 'set_1') == {'SPLIT_2': 'control', 'SPLIT_1': 'control'} + factory.destroy() - def test_get_treatments_with_config(self, mocker): + def test_get_treatments_by_flag_sets(self, mocker): """Test get_treatment execution paths.""" - split_storage = mocker.Mock(spec=SplitStorage) - segment_storage = mocker.Mock(spec=SegmentStorage) - impression_storage = mocker.Mock(spec=ImpressionStorage) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) event_storage = mocker.Mock(spec=EventStorage) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation, + 'SPLIT_1': evaluation + } + _logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + assert client.get_treatments_by_flag_sets('key', ['set_1']) == {'SPLIT_2': 'on', 'SPLIT_1': 'on'} + + impressions_called = impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert client.get_treatments_by_flag_sets('some_key', ['set_2'], {'some_attribute': 1}) == {'SPLIT_1': 'control'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert client.get_treatments_by_flag_sets('key', ['set_1']) == {'SPLIT_2': 'control', 'SPLIT_1': 'control'} + factory.destroy() - def _get_storage_mock(name): - return { - 'splits': split_storage, - 'segments': segment_storage, - 'impressions': impression_storage, - 'events': event_storage, - }[name] + def test_get_treatments_with_config(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) destroyed_property = mocker.PropertyMock() destroyed_property.return_value = False + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_1': evaluation, + 'SPLIT_2': evaluation + } + _logger = mocker.Mock() + assert client.get_treatments_with_config('key', ['SPLIT_1', 'SPLIT_2']) == { + 'SPLIT_1': ('on', '{"color": "red"}'), + 'SPLIT_2': ('on', '{"color": "red"}') + } + + impressions_called = impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert client.get_treatments_with_config('some_key', ['SPLIT_1'], {'some_attribute': 1}) == {'SPLIT_1': ('control', None)} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert client.get_treatments_with_config('key', ['SPLIT_1', 'SPLIT_2']) == { + 'SPLIT_1': ('control', None), + 'SPLIT_2': ('control', None) + } + factory.destroy() + + def test_get_treatments_with_config_by_flag_set(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) - factory = mocker.Mock(spec=SplitFactory) - factory._get_storage.side_effect = _get_storage_mock - factory._waiting_fork.return_value = False - type(factory).destroyed = destroyed_property + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) - impmanager = mocker.Mock(spec=ImpressionManager) - recorder = StandardRecorder(impmanager, event_storage, impression_storage) - client = Client(factory, recorder, True) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) client._evaluator = mocker.Mock(spec=Evaluator) evaluation = { 'treatment': 'on', @@ -268,98 +575,395 @@ def _get_storage_mock(name): 'impression': { 'label': 'some_label', 'change_number': 123 - } + }, + 'impressions_disabled': False } - client._evaluator.evaluate_features.return_value = { - 'f1': evaluation, - 'f2': evaluation + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_1': evaluation, + 'SPLIT_2': evaluation } _logger = mocker.Mock() - assert client.get_treatments_with_config('key', ['f1', 'f2']) == { - 'f1': ('on', '{"color": "red"}'), - 'f2': ('on', '{"color": "red"}') + assert client.get_treatments_with_config_by_flag_set('key', 'set_1') == { + 'SPLIT_1': ('on', '{"color": "red"}'), + 'SPLIT_2': ('on', '{"color": "red"}') } - impressions_called = impmanager.process_impressions.mock_calls[0][1][0] - assert (Impression('key', 'f1', 'on', 'some_label', 123, None, 1000), None) in impressions_called - assert (Impression('key', 'f2', 'on', 'some_label', 123, None, 1000), None) in impressions_called + impressions_called = impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called assert _logger.mock_calls == [] # Test with client not ready ready_property = mocker.PropertyMock() ready_property.return_value = False type(factory).ready = ready_property - impmanager.process_impressions.reset_mock() - assert client.get_treatments_with_config('some_key', ['some_feature'], {'some_attribute': 1}) == {'some_feature': ('control', None)} - assert mocker.call( - [(Impression('some_key', 'some_feature', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY), {'some_attribute': 1})] - ) in impmanager.process_impressions.mock_calls + assert client.get_treatments_with_config_by_flag_set('some_key', 'set_2', {'some_attribute': 1}) == {'SPLIT_1': ('control', None)} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] # Test with exception: ready_property.return_value = True - split_storage.get_change_number.return_value = -1 def _raise(*_): - raise Exception('something') - client._evaluator.evaluate_features.side_effect = _raise - assert client.get_treatments_with_config('key', ['f1', 'f2']) == { - 'f1': ('control', None), - 'f2': ('control', None) + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert client.get_treatments_with_config_by_flag_set('key', 'set_1') == {'SPLIT_1': ('control', None), 'SPLIT_2': ('control', None)} + factory.destroy() + + def test_get_treatments_with_config_by_flag_sets(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_1': evaluation, + 'SPLIT_2': evaluation + } + _logger = mocker.Mock() + assert client.get_treatments_with_config_by_flag_sets('key', ['set_1']) == { + 'SPLIT_1': ('on', '{"color": "red"}'), + 'SPLIT_2': ('on', '{"color": "red"}') + } + + impressions_called = impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert client.get_treatments_with_config_by_flag_sets('some_key', ['set_2'], {'some_attribute': 1}) == {'SPLIT_1': ('control', None)} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert client.get_treatments_with_config_by_flag_sets('key', ['set_1']) == {'SPLIT_1': ('control', None), 'SPLIT_2': ('control', None)} + factory.destroy() + + def test_impression_toggle_optimized(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = ImpressionManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + + factory.block_until_ready(5) + + split_storage.update([ + from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + assert client.get_treatment('some_key', 'SPLIT_1') == 'off' + assert client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert client.get_treatment('some_key', 'SPLIT_3') == 'on' + + impressions = impression_storage.pop_many(100) + assert len(impressions) == 2 + + found1 = False + found2 = False + for impression in impressions: + if impression[1] == 'SPLIT_1': + found1 = True + if impression[1] == 'SPLIT_2': + found2 = True + assert found1 + assert found2 + factory.destroy() + + def test_impression_toggle_debug(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + + factory.block_until_ready(5) + + split_storage.update([ + from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + assert client.get_treatment('some_key', 'SPLIT_1') == 'off' + assert client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert client.get_treatment('some_key', 'SPLIT_3') == 'on' + + impressions = impression_storage.pop_many(100) + assert len(impressions) == 2 + + found1 = False + found2 = False + for impression in impressions: + if impression[1] == 'SPLIT_1': + found1 = True + if impression[1] == 'SPLIT_2': + found2 = True + assert found1 + assert found2 + factory.destroy() + + def test_impression_toggle_none(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + non_strategy = StrategyNoneMode() + impmanager = ImpressionManager(non_strategy, non_strategy, telemetry_runtime_producer) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + + factory.block_until_ready(5) + + split_storage.update([ + from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + assert client.get_treatment('some_key', 'SPLIT_1') == 'off' + assert client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert client.get_treatment('some_key', 'SPLIT_3') == 'on' + + impressions = impression_storage.pop_many(100) + assert len(impressions) == 0 + factory.destroy() + + @mock.patch('splitio.client.factory.SplitFactory.destroy') def test_destroy(self, mocker): """Test that destroy/destroyed calls are forwarded to the factory.""" split_storage = mocker.Mock(spec=SplitStorage) segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) impression_storage = mocker.Mock(spec=ImpressionStorage) event_storage = mocker.Mock(spec=EventStorage) - def _get_storage_mock(name): - return { - 'splits': split_storage, - 'segments': segment_storage, - 'impressions': impression_storage, - 'events': event_storage, - }[name] - factory = mocker.Mock(spec=SplitFactory) - destroyed_mock = mocker.PropertyMock() - type(factory).destroyed = destroyed_mock - impmanager = mocker.Mock(spec=ImpressionManager) - recorder = StandardRecorder(impmanager, event_storage, impression_storage) - client = Client(factory, recorder, True) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) client.destroy() - assert factory.destroy.mock_calls == [mocker.call()] assert client.destroyed is not None - assert destroyed_mock.mock_calls == [mocker.call()] + assert(mocker.called) def test_track(self, mocker): """Test that destroy/destroyed calls are forwarded to the factory.""" split_storage = mocker.Mock(spec=SplitStorage) segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) impression_storage = mocker.Mock(spec=ImpressionStorage) event_storage = mocker.Mock(spec=EventStorage) event_storage.put.return_value = True - def _get_storage_mock(name): - return { - 'splits': split_storage, - 'segments': segment_storage, - 'impressions': impression_storage, - 'events': event_storage, - }[name] - factory = mocker.Mock(spec=SplitFactory) - factory._get_storage = _get_storage_mock + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + destroyed_mock = mocker.PropertyMock() destroyed_mock.return_value = False - factory._waiting_fork.return_value = False - type(factory).destroyed = destroyed_mock factory._apikey = 'test' mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) - impmanager = mocker.Mock(spec=ImpressionManager) - recorder = StandardRecorder(impmanager, event_storage, impression_storage) - client = Client(factory, recorder, True) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) assert client.track('key', 'user', 'purchase', 12) is True assert mocker.call([ EventWrapper( @@ -367,28 +971,60 @@ def _get_storage_mock(name): size=1024 ) ]) in event_storage.put.mock_calls + factory.destroy() def test_evaluations_before_running_post_fork(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) destroyed_property = mocker.PropertyMock() destroyed_property.return_value = False - factory = mocker.Mock(spec=SplitFactory) - factory._waiting_fork.return_value = True - type(factory).destroyed = destroyed_property + impmanager = mocker.Mock(spec=ImpressionManager) + recorder = StandardRecorder(impmanager, mocker.Mock(), impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': mocker.Mock()}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + True + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() expected_msg = [ mocker.call('Client is not ready - no calls possible') ] - client = Client(factory, mocker.Mock()) + client = Client(factory, mocker.Mock(), mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) _logger = mocker.Mock() mocker.patch('splitio.client.client._LOGGER', new=_logger) - assert client.get_treatment('some_key', 'some_feature') == CONTROL + assert client.get_treatment('some_key', 'SPLIT_2') == CONTROL assert _logger.error.mock_calls == expected_msg _logger.reset_mock() - assert client.get_treatment_with_config('some_key', 'some_feature') == (CONTROL, None) + assert client.get_treatment_with_config('some_key', 'SPLIT_2') == (CONTROL, None) assert _logger.error.mock_calls == expected_msg _logger.reset_mock() @@ -396,10 +1032,2445 @@ def test_evaluations_before_running_post_fork(self, mocker): assert _logger.error.mock_calls == expected_msg _logger.reset_mock() - assert client.get_treatments(None, ['some_feature']) == {'some_feature': CONTROL} + assert client.get_treatments(None, ['SPLIT_2']) == {'SPLIT_2': CONTROL} + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() + + assert client.get_treatments_by_flag_set(None, 'set_1') == {'SPLIT_2': CONTROL} + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() + + assert client.get_treatments_by_flag_sets(None, ['set_1']) == {'SPLIT_2': CONTROL} + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() + + assert client.get_treatments_with_config('some_key', ['SPLIT_2']) == {'SPLIT_2': (CONTROL, None)} + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() + + assert client.get_treatments_with_config_by_flag_set('some_key', 'set_1') == {'SPLIT_2': (CONTROL, None)} assert _logger.error.mock_calls == expected_msg _logger.reset_mock() - assert client.get_treatments_with_config('some_key', ['some_feature']) == {'some_feature': (CONTROL, None)} + assert client.get_treatments_with_config_by_flag_sets('some_key', ['set_1']) == {'SPLIT_2': (CONTROL, None)} assert _logger.error.mock_calls == expected_msg _logger.reset_mock() + factory.destroy() + + @mock.patch('splitio.client.client.Client.ready', side_effect=None) + def test_telemetry_not_ready(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + recorder = StandardRecorder(impmanager, mocker.Mock(), mocker.Mock(), telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory('localhost', + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': mocker.Mock()}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + client = Client(factory, mocker.Mock(), mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) + client.ready = False + assert client.get_treatment('some_key', 'SPLIT_2') == CONTROL + assert(telemetry_storage._tel_config._not_ready == 1) + client.track('key', 'tt', 'ev') + assert(telemetry_storage._tel_config._not_ready == 2) + factory.destroy() + + def test_telemetry_record_treatment_exception(self, mocker): + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory('localhost', + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + class SyncManagerMock(): + def stop(*_): + pass + factory._sync_manager = SyncManagerMock() + + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context = _raise + client._evaluator.eval_with_context = _raise + + + try: + client.get_treatment('key', 'SPLIT_2') + except: + pass + assert(telemetry_storage._method_exceptions._treatment == 1) + try: + client.get_treatment_with_config('key', 'SPLIT_2') + except: + pass + assert(telemetry_storage._method_exceptions._treatment_with_config == 1) + + try: + client.get_treatments('key', ['SPLIT_2']) + except: + pass + assert(telemetry_storage._method_exceptions._treatments == 1) + + try: + client.get_treatments_by_flag_set('key', 'set_1') + except: + pass + assert(telemetry_storage._method_exceptions._treatments_by_flag_set == 1) + + try: + client.get_treatments_by_flag_sets('key', ['set_1']) + except: + pass + assert(telemetry_storage._method_exceptions._treatments_by_flag_sets == 1) + + try: + client.get_treatments_with_config('key', ['SPLIT_2']) + except: + pass + assert(telemetry_storage._method_exceptions._treatments_with_config == 1) + + try: + client.get_treatments_with_config_by_flag_set('key', 'set_1') + except: + pass + assert(telemetry_storage._method_exceptions._treatments_with_config_by_flag_set == 1) + + try: + client.get_treatments_with_config_by_flag_sets('key', ['set_1']) + except: + pass + assert(telemetry_storage._method_exceptions._treatments_with_config_by_flag_sets == 1) + factory.destroy() + + def test_telemetry_method_latency(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda:1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + def stop(*_): + pass + factory._sync_manager.stop = stop + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + assert client.get_treatment('key', 'SPLIT_2') == 'on' + assert(telemetry_storage._method_latencies._treatment[0] == 1) + + client.get_treatment_with_config('key', 'SPLIT_2') + assert(telemetry_storage._method_latencies._treatment_with_config[0] == 1) + + client.get_treatments('key', ['SPLIT_2']) + assert(telemetry_storage._method_latencies._treatments[0] == 1) + + client.get_treatments_by_flag_set('key', 'set_1') + assert(telemetry_storage._method_latencies._treatments_by_flag_set[0] == 1) + + client.get_treatments_by_flag_sets('key', ['set_1']) + assert(telemetry_storage._method_latencies._treatments_by_flag_sets[0] == 1) + + client.get_treatments_with_config('key', ['SPLIT_2']) + assert(telemetry_storage._method_latencies._treatments_with_config[0] == 1) + + client.get_treatments_with_config_by_flag_set('key', 'set_1') + assert(telemetry_storage._method_latencies._treatments_with_config_by_flag_set[0] == 1) + + client.get_treatments_with_config_by_flag_sets('key', ['set_1']) + assert(telemetry_storage._method_latencies._treatments_with_config_by_flag_sets[0] == 1) + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + client.track('key', 'tt', 'ev') + assert(telemetry_storage._method_latencies._track[0] == 1) + factory.destroy() + + @mock.patch('splitio.recorder.recorder.StandardRecorder.record_track_stats', side_effect=Exception()) + def test_telemetry_track_exception(self, mocker): + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + try: + client.track('key', 'tt', 'ev') + except: + pass + assert(telemetry_storage._method_exceptions._track == 1) + factory.destroy() + + def test_impressions_properties(self, mocker): + """Test get_treatment execution paths.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer(), + unique_keys_tracker=UniqueKeysTracker(), + imp_counter=ImpressionsCounter()) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + factory.block_until_ready(5) + + split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': None, + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_with_context.return_value = evaluation + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation + } + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + assert client.get_treatment('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions({"prop": "value"})) == 'on' + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert client.get_treatment('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions(12)) == 'on' + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatment')] + + _logger.reset_mock() + assert client.get_treatment('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions('12')) == 'on' + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatment')] + + assert client.get_treatment_with_config('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions({"prop": "value"})) == ('on', None) + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': 'on'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + _logger.reset_mock() + assert client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions("prop")) == {'SPLIT_2': 'on'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatments')] + + _logger.reset_mock() + assert client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions(123)) == {'SPLIT_2': 'on'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatments')] + + _logger.reset_mock() + assert client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=123) == {'SPLIT_2': 'on'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: evaluation options should be an instance of EvaluationOptions. Setting its value to None.', 'get_treatments')] + + assert client.get_treatments_with_config('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': ('on', None)} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert client.get_treatments_by_flag_set('some_key', 'set_1', evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': 'on'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert client.get_treatments_by_flag_sets('some_key', ['set_1'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': 'on'} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert client.get_treatments_with_config_by_flag_set('some_key', 'set_1', evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': ('on', None)} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert client.get_treatments_with_config_by_flag_sets('some_key', ['set_1'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': ('on', None)} + assert impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + @mock.patch('splitio.engine.evaluator.Evaluator.eval_with_context', side_effect=RuntimeError()) + def test_fallback_treatment_eval_exception(self, mocker): + # using fallback when the evaluator has RuntimeError exception + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + impmanager = ImpressionManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + internal_events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + self.imps = None + def put(impressions): + self.imps = impressions + impression_storage.put = put + + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global", '{"prop": "val"}')))) + + def get_feature_flag_names_by_flag_sets(*_): + return ["some", "some2"] + client._get_feature_flag_names_by_flag_sets = get_feature_flag_names_by_flag_sets + + treatment = client.get_treatment("key", "some") + assert(treatment == "on-global") + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + treatment = client.get_treatments("key_m", ["some", "some2"]) + assert(treatment == {"some": "on-global", "some2": "on-global"}) + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - exception") + assert(self.imps[1].treatment == "on-global") + assert(self.imps[1].label == "fallback - exception") + + assert(client.get_treatment_with_config("key", "some") == ("on-global", '{"prop": "val"}')) + assert(client.get_treatments_with_config("key_m", ["some", "some2"]) == {"some": ("on-global", '{"prop": "val"}'), "some2": ("on-global", '{"prop": "val"}')}) + assert(client.get_treatments_by_flag_set("key_m", "set") == {"some": "on-global", "some2": "on-global"}) + assert(client.get_treatments_by_flag_set("key_m", ["set"]) == {"some": "on-global", "some2": "on-global"}) + assert(client.get_treatments_with_config_by_flag_set("key_m", "set") == {"some": ("on-global", '{"prop": "val"}'), "some2": ("on-global", '{"prop": "val"}')}) + assert(client.get_treatments_with_config_by_flag_sets("key_m", ["set"]) == {"some": ("on-global", '{"prop": "val"}'), "some2": ("on-global", '{"prop": "val"}')}) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global", '{"prop": "val"}'), {'some': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key2", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + treatment = client.get_treatments("key2_m", ["some", "some2"]) + assert(treatment == {"some": "on-local", "some2": "on-global"}) + assert_both = 0 + for imp in self.imps: + if imp.feature_name == "some": + assert_both += 1 + assert(imp.treatment == "on-local") + assert(imp.label == "fallback - exception") + else: + assert_both += 1 + assert(imp.treatment == "on-global") + assert(imp.label == "fallback - exception") + assert assert_both == 2 + + assert(client.get_treatment_with_config("key", "some") == ("on-local", None)) + assert(client.get_treatments_with_config("key_m", ["some", "some2"]) == {"some": ("on-local", None), "some2": ("on-global", '{"prop": "val"}')}) + assert(client.get_treatments_by_flag_set("key_m", "set") == {"some": "on-local", "some2": "on-global"}) + assert(client.get_treatments_by_flag_set("key_m", ["set"]) == {"some": "on-local", "some2": "on-global"}) + assert(client.get_treatments_with_config_by_flag_set("key_m", "set") == {"some": ("on-local", None), "some2": ("on-global", '{"prop": "val"}')}) + assert(client.get_treatments_with_config_by_flag_sets("key_m", ["set"]) == {"some": ("on-local", None), "some2": ("on-global", '{"prop": "val"}')}) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some': FallbackTreatment("on-local", '{"prop": "val"}')})) + treatment = client.get_treatment("key3", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + treatment = client.get_treatments("key3_m", ["some", "some2"]) + assert(treatment == {"some": "on-local", "some2": "control"}) + assert_both = 0 + for imp in self.imps: + if imp.feature_name == "some": + assert_both += 1 + assert(imp.treatment == "on-local") + assert(imp.label == "fallback - exception") + else: + assert_both += 1 + assert(imp.treatment == "control") + assert(imp.label == "exception") + assert assert_both == 2 + + assert(client.get_treatment_with_config("key", "some") == ("on-local", '{"prop": "val"}')) + assert(client.get_treatments_with_config("key_m", ["some", "some2"]) == {"some": ("on-local", '{"prop": "val"}'), "some2": ("control", None)}) + assert(client.get_treatments_by_flag_set("key_m", "set") == {"some": "on-local", "some2": "control"}) + assert(client.get_treatments_by_flag_set("key_m", ["set"]) == {"some": "on-local", "some2": "control"}) + assert(client.get_treatments_with_config_by_flag_set("key_m", "set") == {"some": ("on-local", '{"prop": "val"}'), "some2": ("control", None)}) + assert(client.get_treatments_with_config_by_flag_sets("key_m", ["set"]) == {"some": ("on-local", '{"prop": "val"}'), "some2": ("control", None)}) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some2': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key4", "some") + assert(treatment == "control") + assert(self.imps[0].treatment == "control") + assert(self.imps[0].label == "exception") + + try: + factory.destroy() + except: + pass + + @mock.patch('splitio.engine.evaluator.Evaluator.eval_with_context', side_effect=Exception()) + def test_fallback_treatment_exception(self, mocker): + # using fallback when the evaluator has RuntimeError exception + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + impmanager = ImpressionManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + internal_events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + self.imps = None + def put(impressions): + self.imps = impressions + impression_storage.put = put + + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global")))) + treatment = client.get_treatment("key", "some") + assert(treatment == "on-global") + assert(self.imps == None) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'some': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key2", "some") + assert(treatment == "on-local") + assert(self.imps == None) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key3", "some") + assert(treatment == "on-local") + assert(self.imps == None) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some2': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key4", "some") + assert(treatment == "control") + assert(self.imps == None) + + try: + factory.destroy() + except: + pass + + @mock.patch('splitio.client.client.Client.ready', side_effect=None) + def test_fallback_treatment_not_ready_impressions(self, mocker): + # using fallback when the evaluator has RuntimeError exception + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + impmanager = ImpressionManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + internal_events_queue = queue.Queue() + factory = SplitFactory(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + self.imps = None + def put(impressions): + self.imps = impressions + impression_storage.put = put + + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + client = Client(factory, recorder, mocker.Mock(), True, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global")))) + client.ready = False + + treatment = client.get_treatment("key", "some") + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - not ready") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'some': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key2", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - not ready") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key3", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - not ready") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some2': FallbackTreatment("on-local")})) + treatment = client.get_treatment("key4", "some") + assert(treatment == "control") + assert(self.imps[0].treatment == "control") + assert(self.imps[0].label == "not ready") + + try: + factory.destroy() + except: + pass + + def test_events_subscription(self, mocker): + events_manager = mocker.Mock(spec=EventsManager) + client = Client(mocker.Mock(), mocker.Mock(), events_manager, True, FallbackTreatmentCalculator(None)) + client.on(SdkEvent.SDK_READY, self.test_fallback_treatment_not_ready_impressions) + assert events_manager.register.mock_calls[0] == mock.call(SdkEvent.SDK_READY, self.test_fallback_treatment_not_ready_impressions) + + events_manager.register.mock_calls = [] + client.on("dd", self.test_fallback_treatment_not_ready_impressions) + assert events_manager.register.mock_calls == [] + + client.on(SdkEvent.SDK_READY, "qwe") + assert events_manager.register.mock_calls == [] + +class ClientAsyncTests(object): # pylint: disable=too-few-public-methods + """Split client async test cases.""" + + @pytest.mark.asyncio + async def test_get_treatment_async(self, mocker): + """Test get_treatment_async execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + client._evaluator.eval_with_context.return_value = { + 'treatment': 'on', + 'configurations': None, + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + _logger = mocker.Mock() + assert await client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None)] + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatment('some_key', 'SPLIT_2', {'some_attribute': 1}) == 'control' + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', Label.NOT_READY, None, None, 1000, None, None)] + + # Test with exception: + ready_property.return_value = True + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_with_context.side_effect = _raise + assert await client.get_treatment('some_key', 'SPLIT_2') == 'control' + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', 'exception', None, None, 1000, None, None)] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatment_with_config_async(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + client._evaluator.eval_with_context.return_value = { + 'treatment': 'on', + 'configurations': '{"some_config": True}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + _logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + assert await client.get_treatment_with_config( + 'some_key', + 'SPLIT_2' + ) == ('on', '{"some_config": True}') + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None)] + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatment_with_config('some_key', 'SPLIT_2', {'some_attribute': 1}) == ('control', None) + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_with_context.side_effect = _raise + assert await client.get_treatment_with_config('some_key', 'SPLIT_2') == ('control', None) + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', 'exception', None, None, 1000, None, None)] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_async(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation, + 'SPLIT_1': evaluation + } + _logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + assert await client.get_treatments('key', ['SPLIT_2', 'SPLIT_1']) == {'SPLIT_2': 'on', 'SPLIT_1': 'on'} + + impressions_called = await impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatments('some_key', ['SPLIT_2'], {'some_attribute': 1}) == {'SPLIT_2': 'control'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert await client.get_treatments('key', ['SPLIT_2', 'SPLIT_1']) == {'SPLIT_2': 'control', 'SPLIT_1': 'control'} + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set_async(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation, + 'SPLIT_1': evaluation + } + _logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + assert await client.get_treatments_by_flag_set('key', 'set_1') == {'SPLIT_2': 'on', 'SPLIT_1': 'on'} + + impressions_called = await impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatments_by_flag_set('some_key', 'set_2', {'some_attribute': 1}) == {'SPLIT_1': 'control'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert await client.get_treatments_by_flag_set('key', 'set_1') == {'SPLIT_2': 'control', 'SPLIT_1': 'control'} + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets_async(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation, + 'SPLIT_1': evaluation + } + _logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + assert await client.get_treatments_by_flag_sets('key', ['set_1']) == {'SPLIT_2': 'on', 'SPLIT_1': 'on'} + + impressions_called = await impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatments_by_flag_sets('some_key', ['set_2'], {'some_attribute': 1}) == {'SPLIT_1': 'control'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert await client.get_treatments_by_flag_sets('key', ['set_1']) == {'SPLIT_2': 'control', 'SPLIT_1': 'control'} + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_1': evaluation, + 'SPLIT_2': evaluation + } + _logger = mocker.Mock() + assert await client.get_treatments_with_config('key', ['SPLIT_1', 'SPLIT_2']) == { + 'SPLIT_1': ('on', '{"color": "red"}'), + 'SPLIT_2': ('on', '{"color": "red"}') + } + + impressions_called = await impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatments_with_config('some_key', ['SPLIT_1'], {'some_attribute': 1}) == {'SPLIT_1': ('control', None)} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert await client.get_treatments_with_config('key', ['SPLIT_1', 'SPLIT_2']) == { + 'SPLIT_1': ('control', None), + 'SPLIT_2': ('control', None) + } + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_1': evaluation, + 'SPLIT_2': evaluation + } + _logger = mocker.Mock() + assert await client.get_treatments_with_config_by_flag_set('key', 'set_1') == { + 'SPLIT_1': ('on', '{"color": "red"}'), + 'SPLIT_2': ('on', '{"color": "red"}') + } + + impressions_called = await impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatments_with_config_by_flag_set('some_key', 'set_2', {'some_attribute': 1}) == {'SPLIT_1': ('control', None)} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert await client.get_treatments_with_config_by_flag_set('key', 'set_1') == { + 'SPLIT_1': ('control', None), + 'SPLIT_2': ('control', None) + } + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0]), from_raw(splits_json['splitChange1_1']['ff']['d'][1])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': '{"color": "red"}', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_1': evaluation, + 'SPLIT_2': evaluation + } + _logger = mocker.Mock() + assert await client.get_treatments_with_config_by_flag_sets('key', ['set_1']) == { + 'SPLIT_1': ('on', '{"color": "red"}'), + 'SPLIT_2': ('on', '{"color": "red"}') + } + + impressions_called = await impression_storage.pop_many(100) + assert Impression('key', 'SPLIT_1', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert Impression('key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None) in impressions_called + assert _logger.mock_calls == [] + + # Test with client not ready + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + assert await client.get_treatments_with_config_by_flag_sets('some_key', ['set_2'], {'some_attribute': 1}) == {'SPLIT_1': ('control', None)} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_1', 'control', Label.NOT_READY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)] + + # Test with exception: + ready_property.return_value = True + + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_many_with_context.side_effect = _raise + assert await client.get_treatments_with_config_by_flag_sets('key', ['set_1']) == { + 'SPLIT_1': ('control', None), + 'SPLIT_2': ('control', None) + } + await factory.destroy() + + @pytest.mark.asyncio + async def test_impression_toggle_optimized(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = ImpressionManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + await factory.block_until_ready(5) + + await split_storage.update([ + from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + treatment = await client.get_treatment('some_key', 'SPLIT_1') + assert treatment == 'off' + treatment = await client.get_treatment('some_key', 'SPLIT_2') + assert treatment == 'on' + treatment = await client.get_treatment('some_key', 'SPLIT_3') + assert treatment == 'on' + + impressions = await impression_storage.pop_many(100) + assert len(impressions) == 2 + + found1 = False + found2 = False + for impression in impressions: + if impression[1] == 'SPLIT_1': + found1 = True + if impression[1] == 'SPLIT_2': + found2 = True + assert found1 + assert found2 + await factory.destroy() + + @pytest.mark.asyncio + async def test_impression_toggle_debug(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + await factory.block_until_ready(5) + + await split_storage.update([ + from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + assert await client.get_treatment('some_key', 'SPLIT_1') == 'off' + assert await client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert await client.get_treatment('some_key', 'SPLIT_3') == 'on' + + impressions = await impression_storage.pop_many(100) + assert len(impressions) == 2 + + found1 = False + found2 = False + for impression in impressions: + if impression[1] == 'SPLIT_1': + found1 = True + if impression[1] == 'SPLIT_2': + found2 = True + assert found1 + assert found2 + await factory.destroy() + + @pytest.mark.asyncio + async def test_impression_toggle_none(self, mocker): + """Test get_treatment execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + non_strategy = StrategyNoneMode() + impmanager = ImpressionManager(non_strategy, non_strategy, telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + await factory.block_until_ready(5) + + await split_storage.update([ + from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + assert await client.get_treatment('some_key', 'SPLIT_1') == 'off' + assert await client.get_treatment('some_key', 'SPLIT_2') == 'on' + assert await client.get_treatment('some_key', 'SPLIT_3') == 'on' + + impressions = await impression_storage.pop_many(100) + assert len(impressions) == 0 + await factory.destroy() + + @pytest.mark.asyncio + async def test_track_async(self, mocker): + """Test that destroy/destroyed calls are forwarded to the factory.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + self.events = [] + async def put(event): + self.events.append(event) + return True + event_storage.put = put + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + destroyed_mock = mocker.PropertyMock() + destroyed_mock.return_value = False + factory._apikey = 'test' + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + assert await client.track('key', 'user', 'purchase', 12) is True + assert self.events[0] == [EventWrapper( + event=Event('key', 'user', 'purchase', 12, 1000, None), + size=1024 + )] + await factory.destroy() + + @pytest.mark.asyncio + async def test_telemetry_not_ready_async(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = InMemoryEventStorageAsync(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + factory = SplitFactoryAsync('localhost', + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': mocker.Mock()}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + assert await client.get_treatment('some_key', 'SPLIT_2') == CONTROL + assert(telemetry_storage._tel_config._not_ready == 1) + await client.track('key', 'tt', 'ev') + assert(telemetry_storage._tel_config._not_ready == 2) + await factory.destroy() + + @pytest.mark.asyncio + async def test_telemetry_record_treatment_exception_async(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = InMemoryEventStorageAsync(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock() + def _raise(*_): + raise RuntimeError('something') + client._evaluator.eval_with_context.side_effect = _raise + client._evaluator.eval_many_with_context.side_effect = _raise + + await client.get_treatment('key', 'SPLIT_2') + assert(telemetry_storage._method_exceptions._treatment == 1) + + await client.get_treatment_with_config('key', 'SPLIT_2') + assert(telemetry_storage._method_exceptions._treatment_with_config == 1) + + await client.get_treatments('key', ['SPLIT_2']) + assert(telemetry_storage._method_exceptions._treatments == 1) + + await client.get_treatments_by_flag_set('key', 'set_1') + assert(telemetry_storage._method_exceptions._treatments_by_flag_set == 1) + + await client.get_treatments_by_flag_sets('key', ['set_1']) + assert(telemetry_storage._method_exceptions._treatments_by_flag_sets == 1) + + await client.get_treatments_with_config('key', ['SPLIT_2']) + assert(telemetry_storage._method_exceptions._treatments_with_config == 1) + + await client.get_treatments_with_config_by_flag_set('key', 'set_1') + assert(telemetry_storage._method_exceptions._treatments_with_config_by_flag_set == 1) + + await client.get_treatments_with_config_by_flag_sets('key', ['set_1']) + assert(telemetry_storage._method_exceptions._treatments_with_config_by_flag_sets == 1) + + await factory.destroy() + + @pytest.mark.asyncio + async def test_telemetry_method_latency_async(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = InMemoryEventStorageAsync(10, telemetry_runtime_producer) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + await factory.block_until_ready(1) + except: + pass + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + assert await client.get_treatment('key', 'SPLIT_2') == 'on' + assert(telemetry_storage._method_latencies._treatment[0] == 1) + + await client.get_treatment_with_config('key', 'SPLIT_2') + assert(telemetry_storage._method_latencies._treatment_with_config[0] == 1) + + await client.get_treatments('key', ['SPLIT_2']) + assert(telemetry_storage._method_latencies._treatments[0] == 1) + + await client.get_treatments_by_flag_set('key', 'set_1') + assert(telemetry_storage._method_latencies._treatments_by_flag_set[0] == 1) + + await client.get_treatments_by_flag_sets('key', ['set_1']) + assert(telemetry_storage._method_latencies._treatments_by_flag_sets[0] == 1) + + await client.get_treatments_with_config('key', ['SPLIT_2']) + assert(telemetry_storage._method_latencies._treatments_with_config[0] == 1) + + await client.get_treatments_with_config_by_flag_set('key', 'set_1') + assert(telemetry_storage._method_latencies._treatments_with_config_by_flag_set[0] == 1) + + await client.get_treatments_with_config_by_flag_sets('key', ['set_1']) + assert(telemetry_storage._method_latencies._treatments_with_config_by_flag_sets[0] == 1) + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + await client.track('key', 'tt', 'ev') + assert(telemetry_storage._method_latencies._track[0] == 1) + await factory.destroy() + + @pytest.mark.asyncio + async def test_telemetry_track_exception_async(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + event_storage = InMemoryEventStorageAsync(10, telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + async def exc(*_): + raise RuntimeError("something") + recorder.record_track_stats = exc + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + try: + await client.track('key', 'tt', 'ev') + except: + pass + assert(telemetry_storage._method_exceptions._track == 1) + await factory.destroy() + + @pytest.mark.asyncio + async def test_impressions_properties_async(self, mocker): + """Test get_treatment_async execution paths.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorageAsync(10, telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer(), imp_counter=ImpressionsCounter()) + await split_storage.update([from_raw(splits_json['splitChange1_1']['ff']['d'][0])], [], -1) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + ) + + await factory.block_until_ready(1) + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(None)) + client._evaluator = mocker.Mock(spec=Evaluator) + evaluation = { + 'treatment': 'on', + 'configurations': None, + 'impression': { + 'label': 'some_label', + 'change_number': 123 + }, + 'impressions_disabled': False + } + client._evaluator.eval_with_context.return_value = evaluation + client._evaluator.eval_many_with_context.return_value = { + 'SPLIT_2': evaluation + } + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + assert await client.get_treatment('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions({"prop": "value"})) == 'on' + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert await client.get_treatment('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions(12)) == 'on' + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatment')] + + _logger.reset_mock() + assert await client.get_treatment('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions('12')) == 'on' + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatment')] + + assert await client.get_treatment_with_config('some_key', 'SPLIT_2', evaluation_options=EvaluationOptions({"prop": "value"})) == ('on', None) + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert await client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': 'on'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + _logger.reset_mock() + assert await client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions("prop")) == {'SPLIT_2': 'on'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatments')] + + _logger.reset_mock() + assert await client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions(123)) == {'SPLIT_2': 'on'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: properties must be of type dictionary.', 'get_treatments')] + + _logger.reset_mock() + assert await client.get_treatments('some_key', ['SPLIT_2'], evaluation_options=123) == {'SPLIT_2': 'on'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, 1000, None)] + assert _logger.error.mock_calls == [mocker.call('%s: evaluation options should be an instance of EvaluationOptions. Setting its value to None.', 'get_treatments')] + + assert await client.get_treatments_with_config('some_key', ['SPLIT_2'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': ('on', None)} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert await client.get_treatments_by_flag_set('some_key', 'set_1', evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': 'on'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert await client.get_treatments_by_flag_sets('some_key', ['set_1'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': 'on'} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert await client.get_treatments_with_config_by_flag_set('some_key', 'set_1', evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': ('on', None)} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + + assert await client.get_treatments_with_config_by_flag_sets('some_key', ['set_1'], evaluation_options=EvaluationOptions({"prop": "value"})) == {'SPLIT_2': ('on', None)} + assert await impression_storage.pop_many(100) == [Impression('some_key', 'SPLIT_2', 'on', 'some_label', 123, None, 1000, None, '{"prop": "value"}')] + try: + await factory.destroy() + except: + pass + + @pytest.mark.asyncio + async def test_fallback_treatment_eval_exception(self, mocker): + # using fallback when the evaluator has RuntimeError exception + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_evaluation_producer, telemetry_producer.get_telemetry_runtime_producer()) + + class TelemetrySubmitterMock(): + async def synchronize_config(*_): + pass + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + TelemetrySubmitterMock(), + None + ) + + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + self.imps = None + async def put(impressions): + self.imps = impressions + impression_storage.put = put + + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global", '{"prop": "val"}')))) + + def eval_with_context(*_): + raise RuntimeError() + client._evaluator.eval_with_context = eval_with_context + + async def get_feature_flag_names_by_flag_sets(*_): + return ["some", "some2"] + client._get_feature_flag_names_by_flag_sets = get_feature_flag_names_by_flag_sets + + async def fetch_many(*_): + return {"some": from_raw(splits_json['splitChange1_1']['ff']['d'][0])} + split_storage.fetch_many = fetch_many + + async def fetch_many_rbs(*_): + return {} + rb_segment_storage.fetch_many = fetch_many_rbs + + treatment = await client.get_treatment("key", "some") + assert(treatment == "on-global") + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + treatment = await client.get_treatments("key_m", ["some", "some2"]) + assert(treatment == {"some": "on-global", "some2": "on-global"}) + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - exception") + assert(self.imps[1].treatment == "on-global") + assert(self.imps[1].label == "fallback - exception") + + assert(await client.get_treatment_with_config("key", "some") == ("on-global", '{"prop": "val"}')) + assert(await client.get_treatments_with_config("key_m", ["some", "some2"]) == {"some": ("on-global", '{"prop": "val"}'), "some2": ("on-global", '{"prop": "val"}')}) + assert(await client.get_treatments_by_flag_set("key_m", "set") == {"some": "on-global", "some2": "on-global"}) + assert(await client.get_treatments_by_flag_set("key_m", ["set"]) == {"some": "on-global", "some2": "on-global"}) + assert(await client.get_treatments_with_config_by_flag_set("key_m", "set") == {"some": ("on-global", '{"prop": "val"}'), "some2": ("on-global", '{"prop": "val"}')}) + assert(await client.get_treatments_with_config_by_flag_sets("key_m", ["set"]) == {"some": ("on-global", '{"prop": "val"}'), "some2": ("on-global", '{"prop": "val"}')}) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global", '{"prop": "val"}'), {'some': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key2", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + treatment = await client.get_treatments("key2_m", ["some", "some2"]) + assert(treatment == {"some": "on-local", "some2": "on-global"}) + assert_both = 0 + for imp in self.imps: + if imp.feature_name == "some": + assert_both += 1 + assert(imp.treatment == "on-local") + assert(imp.label == "fallback - exception") + else: + assert_both += 1 + assert(imp.treatment == "on-global") + assert(imp.label == "fallback - exception") + assert assert_both == 2 + + assert(await client.get_treatment_with_config("key", "some") == ("on-local", None)) + assert(await client.get_treatments_with_config("key_m", ["some", "some2"]) == {"some": ("on-local", None), "some2": ("on-global", '{"prop": "val"}')}) + assert(await client.get_treatments_by_flag_set("key_m", "set") == {"some": "on-local", "some2": "on-global"}) + assert(await client.get_treatments_by_flag_set("key_m", ["set"]) == {"some": "on-local", "some2": "on-global"}) + assert(await client.get_treatments_with_config_by_flag_set("key_m", "set") == {"some": ("on-local", None), "some2": ("on-global", '{"prop": "val"}')}) + assert(await client.get_treatments_with_config_by_flag_sets("key_m", ["set"]) == {"some": ("on-local", None), "some2": ("on-global", '{"prop": "val"}')}) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some': FallbackTreatment("on-local", '{"prop": "val"}')})) + treatment = await client.get_treatment("key3", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + treatment = await client.get_treatments("key3_m", ["some", "some2"]) + assert(treatment == {"some": "on-local", "some2": "control"}) + assert_both = 0 + for imp in self.imps: + if imp.feature_name == "some": + assert_both += 1 + assert(imp.treatment == "on-local") + assert(imp.label == "fallback - exception") + else: + assert_both += 1 + assert(imp.treatment == "control") + assert(imp.label == "exception") + assert assert_both == 2 + + assert(await client.get_treatment_with_config("key", "some") == ("on-local", '{"prop": "val"}')) + assert(await client.get_treatments_with_config("key_m", ["some", "some2"]) == {"some": ("on-local", '{"prop": "val"}'), "some2": ("control", None)}) + assert(await client.get_treatments_by_flag_set("key_m", "set") == {"some": "on-local", "some2": "control"}) + assert(await client.get_treatments_by_flag_set("key_m", ["set"]) == {"some": "on-local", "some2": "control"}) + assert(await client.get_treatments_with_config_by_flag_set("key_m", "set") == {"some": ("on-local", '{"prop": "val"}'), "some2": ("control", None)}) + assert(await client.get_treatments_with_config_by_flag_sets("key_m", ["set"]) == {"some": ("on-local", '{"prop": "val"}'), "some2": ("control", None)}) + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some2': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key4", "some") + assert(treatment == "control") + assert(self.imps[0].treatment == "control") + assert(self.imps[0].label == "exception") + + try: + await factory.destroy() + except: + pass + + @pytest.mark.asyncio + async def test_fallback_treatment_exception(self, mocker): + # using fallback when the evaluator has RuntimeError exception + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + impmanager = ImpressionManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + self.imps = None + async def put(impressions): + self.imps = impressions + impression_storage.put = put + + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global")))) + + def eval_with_context(*_): + raise Exception() + client._evaluator.eval_with_context = eval_with_context + + async def context_for(*_): + return EvaluationContext( + {}, + {}, + {} + ) + client._context_factory.context_for = context_for + + treatment = await client.get_treatment("key", "some") + assert(treatment == "on-global") + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'some': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key2", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key3", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - exception") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some2': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key4", "some") + assert(treatment == "control") + assert(self.imps[0].treatment == "control") + assert(self.imps[0].label == "exception") + + try: + await factory.destroy() + except: + pass + + @pytest.mark.asyncio + async def test_fallback_treatment_not_ready_impressions(self, mocker): + # using fallback when the evaluator has RuntimeError exception + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + rb_segment_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + + mocker.patch('splitio.client.client.utctime_ms', new=lambda: 1000) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + impmanager = ImpressionManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_producer.get_telemetry_runtime_producer()) + recorder = StandardRecorderAsync(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + async def manager_start_task(): + pass + + factory = SplitFactoryAsync(mocker.Mock(), + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + manager_start_task + ) + + self.imps = None + async def put(impressions): + self.imps = impressions + impression_storage.put = put + + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + client = ClientAsync(factory, recorder, events_manager, True, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global")))) + ready_property = mocker.PropertyMock() + ready_property.return_value = False + type(factory).ready = ready_property + + async def context_for(*_): + return EvaluationContext( + {"some": {}}, + {}, + {} + ) + client._context_factory.context_for = context_for + + treatment = await client.get_treatment("key", "some") + assert(self.imps[0].treatment == "on-global") + assert(self.imps[0].label == "fallback - not ready") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'some': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key2", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - not ready") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key3", "some") + assert(treatment == "on-local") + assert(self.imps[0].treatment == "on-local") + assert(self.imps[0].label == "fallback - not ready") + + self.imps = None + client._fallback_treatment_calculator = FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'some2': FallbackTreatment("on-local")})) + treatment = await client.get_treatment("key4", "some") + assert(treatment == "control") + assert(self.imps[0].treatment == "control") + assert(self.imps[0].label == "not ready") + + try: + await factory.destroy() + except: + pass + + @pytest.mark.asyncio + async def test_events_subscription(self, mocker): + events_manager = mocker.Mock(spec=EventsManagerAsync) + self.event = None + self.handle = None + async def register(sdk_event, callback_handle): + self.event = sdk_event + self.handle = callback_handle + events_manager.register = register + + client = ClientAsync(mocker.Mock(), mocker.Mock(), events_manager, True, FallbackTreatmentCalculator(None)) + await client.on(SdkEvent.SDK_READY, self.event_callback) + assert self.event == SdkEvent.SDK_READY + assert self.handle == self.event_callback + + self.event = None + await client.on("dd", self.event_callback) + assert self.event == None + + await client.on(SdkEvent.SDK_READY, "qwe") + assert self.event == None + + async def event_callback(self, metadata): + pass \ No newline at end of file diff --git a/tests/client/test_config.py b/tests/client/test_config.py index a52600bd..e08a1d4b 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -1,52 +1,125 @@ """Configuration unit tests.""" # pylint: disable=protected-access,no-self-use,line-too-long - +import pytest from splitio.client import config -from splitio.engine.impressions import ImpressionsMode - +from splitio.engine.impressions.impressions import ImpressionsMode +from splitio.models.fallback_treatment import FallbackTreatment +from splitio.models.fallback_config import FallbackTreatmentsConfiguration class ConfigSanitizationTests(object): """Inmemory storage-based integration tests.""" def test_parse_operation_mode(self): """Make sure operation mode is correctly captured.""" - assert config._parse_operation_mode('some', {}) == 'inmemory-standalone' - assert config._parse_operation_mode('localhost', {}) == 'localhost-standalone' - assert config._parse_operation_mode('some', {'redisHost': 'x'}) == 'redis-consumer' + assert (config._parse_operation_mode('some', {})) == ('standalone', 'memory') + assert (config._parse_operation_mode('localhost', {})) == ('localhost', 'localhost') + assert (config._parse_operation_mode('some', {'redisHost': 'x'})) == ('consumer', 'redis') + assert (config._parse_operation_mode('some', {'storageType': 'pluggable'})) == ('consumer', 'pluggable') + assert (config._parse_operation_mode('some', {'storageType': 'custom2'})) == ('standalone', 'memory') def test_sanitize_imp_mode(self): """Test sanitization of impressions mode.""" - mode, rate = config._sanitize_impressions_mode('OPTIMIZED', 1) + mode, rate = config._sanitize_impressions_mode('memory', 'OPTIMIZED', 1) assert mode == ImpressionsMode.OPTIMIZED assert rate == 60 - mode, rate = config._sanitize_impressions_mode('DEBUG', 1) + mode, rate = config._sanitize_impressions_mode('memory', 'DEBUG', 1) assert mode == ImpressionsMode.DEBUG assert rate == 1 - mode, rate = config._sanitize_impressions_mode('debug', 1) + mode, rate = config._sanitize_impressions_mode('redis', 'OPTIMIZED', 1) + assert mode == ImpressionsMode.OPTIMIZED + assert rate == 60 + + mode, rate = config._sanitize_impressions_mode('redis', 'debug', 1) assert mode == ImpressionsMode.DEBUG assert rate == 1 - mode, rate = config._sanitize_impressions_mode('ANYTHING', 200) + mode, rate = config._sanitize_impressions_mode('memory', 'ANYTHING', 200) + assert mode == ImpressionsMode.OPTIMIZED + assert rate == 200 + + mode, rate = config._sanitize_impressions_mode('pluggable', 'ANYTHING', 200) + assert mode == ImpressionsMode.OPTIMIZED + assert rate == 200 + + mode, rate = config._sanitize_impressions_mode('pluggable', 'NONE', 200) + assert mode == ImpressionsMode.NONE + assert rate == 200 + + mode, rate = config._sanitize_impressions_mode('pluggable', 'OPTIMIZED', 200) assert mode == ImpressionsMode.OPTIMIZED assert rate == 200 - mode, rate = config._sanitize_impressions_mode(43, -1) + mode, rate = config._sanitize_impressions_mode('memory', 43, -1) assert mode == ImpressionsMode.OPTIMIZED assert rate == 60 - mode, rate = config._sanitize_impressions_mode('OPTIMIZED') + mode, rate = config._sanitize_impressions_mode('memory', 'OPTIMIZED') assert mode == ImpressionsMode.OPTIMIZED assert rate == 300 - mode, rate = config._sanitize_impressions_mode('DEBUG') + mode, rate = config._sanitize_impressions_mode('memory', 'DEBUG') assert mode == ImpressionsMode.DEBUG assert rate == 60 - def test_sanitize(self): + def test_sanitize(self, mocker): """Test sanitization.""" + _logger = mocker.Mock() + mocker.patch('splitio.client.config._LOGGER', new=_logger) configs = {} processed = config.sanitize('some', configs) - assert processed['redisLocalCacheEnabled'] # check default is True + assert processed['flagSetsFilter'] is None + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE + + processed = config.sanitize('some', {'redisHost': 'x', 'flagSetsFilter': ['set']}) + assert processed['flagSetsFilter'] is None + + processed = config.sanitize('some', {'storageType': 'pluggable', 'flagSetsFilter': ['set']}) + assert processed['flagSetsFilter'] is None + + processed = config.sanitize('some', {'httpAuthenticateScheme': 'KERBEROS_spnego'}) + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.KERBEROS_SPNEGO + + processed = config.sanitize('some', {'httpAuthenticateScheme': 'kerberos_proxy'}) + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.KERBEROS_PROXY + + processed = config.sanitize('some', {'httpAuthenticateScheme': 'anything'}) + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE + + processed = config.sanitize('some', {'httpAuthenticateScheme': 'NONE'}) + assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE + + _logger.reset_mock() + processed = config.sanitize('some', {'fallbackTreatments': 'NONE'}) + assert processed['fallbackTreatments'] == None + assert _logger.warning.mock_calls[1] == mocker.call("Config: fallbackTreatments parameter should be of `FallbackTreatmentsConfiguration` class.") + + _logger.reset_mock() + processed = config.sanitize('some', {'fallbackTreatments': FallbackTreatmentsConfiguration(123)}) + assert processed['fallbackTreatments'].global_fallback_treatment == None + assert _logger.warning.mock_calls[1] == mocker.call("Config: global fallbacktreatment parameter is discarded.") + + _logger.reset_mock() + processed = config.sanitize('some', {'fallbackTreatments': FallbackTreatmentsConfiguration(FallbackTreatment(123))}) + assert processed['fallbackTreatments'].global_fallback_treatment == None + assert _logger.warning.mock_calls[1] == mocker.call("Config: global fallbacktreatment parameter is discarded.") + + fb = FallbackTreatmentsConfiguration(FallbackTreatment('on')) + processed = config.sanitize('some', {'fallbackTreatments': fb}) + assert processed['fallbackTreatments'].global_fallback_treatment.treatment == fb.global_fallback_treatment.treatment + assert processed['fallbackTreatments'].global_fallback_treatment.label == None + + fb = FallbackTreatmentsConfiguration(FallbackTreatment('on'), {"flag": FallbackTreatment("off")}) + processed = config.sanitize('some', {'fallbackTreatments': fb}) + assert processed['fallbackTreatments'].global_fallback_treatment.treatment == fb.global_fallback_treatment.treatment + assert processed['fallbackTreatments'].by_flag_fallback_treatment["flag"] == fb.by_flag_fallback_treatment["flag"] + assert processed['fallbackTreatments'].by_flag_fallback_treatment["flag"].label == None + + _logger.reset_mock() + fb = FallbackTreatmentsConfiguration(None, {"flag#%": FallbackTreatment("off"), "flag2": FallbackTreatment("on")}) + processed = config.sanitize('some', {'fallbackTreatments': fb}) + assert len(processed['fallbackTreatments'].by_flag_fallback_treatment) == 1 + assert processed['fallbackTreatments'].by_flag_fallback_treatment.get("flag2") == fb.by_flag_fallback_treatment["flag2"] + assert _logger.warning.mock_calls[1] == mocker.call('Config: fallback treatment parameter for feature flag %s is discarded.', 'flag#%') \ No newline at end of file diff --git a/tests/client/test_factory.py b/tests/client/test_factory.py index 065584e8..1512507c 100644 --- a/tests/client/test_factory.py +++ b/tests/client/test_factory.py @@ -5,42 +5,94 @@ import os import time import threading -from splitio.client.factory import get_factory, SplitFactory, _INSTANTIATED_FACTORIES, Status,\ - _LOGGER as _logger +import pytest +import queue + +from splitio.optional.loaders import asyncio +from splitio.client.factory import get_factory, get_factory_async, SplitFactory, _INSTANTIATED_FACTORIES, Status,\ + _LOGGER as _logger, SplitFactoryAsync from splitio.client.config import DEFAULT_CONFIG -from splitio.storage import redis, inmemmory -from splitio.tasks import events_sync, impressions_sync, split_sync, segment_sync -from splitio.tasks.util import asynctask -from splitio.api.splits import SplitsAPI -from splitio.api.segments import SegmentsAPI -from splitio.api.impressions import ImpressionsAPI -from splitio.api.events import EventsAPI -from splitio.engine.impressions import Manager as ImpressionsManager -from splitio.sync.manager import Manager -from splitio.sync.synchronizer import Synchronizer, SplitSynchronizers, SplitTasks -from splitio.sync.split import SplitSynchronizer -from splitio.sync.segment import SegmentSynchronizer -from splitio.recorder.recorder import PipelinedRecorder, StandardRecorder +from splitio.engine.telemetry import TelemetryStorageConsumer, TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.engine.impressions.impressions import Manager as ImpressionsManager +from splitio.engine.impressions.manager import Counter as ImpressionsCounter +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.engine.telemetry import TelemetryStorageConsumer, TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.engine.evaluator import Evaluator, EvaluationContext +from splitio.engine.impressions.strategies import StrategyDebugMode, StrategyNoneMode, StrategyOptimizedMode +from splitio.events.events_task import EventsTask +from splitio.events.events_manager import EventsManagerAsync +from splitio.models.splits import from_raw +from splitio.models.fallback_config import FallbackTreatmentsConfiguration, FallbackTreatmentCalculator +from splitio.models.fallback_treatment import FallbackTreatment +from splitio.models.events import SdkInternalEvent +from splitio.recorder.recorder import PipelinedRecorder, StandardRecorder, StandardRecorderAsync +from splitio.storage import redis, inmemmory, pluggable, EventStorage +from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \ + InMemoryImpressionStorage, InMemoryTelemetryStorage, InMemorySplitStorageAsync, \ + InMemoryImpressionStorageAsync, InMemorySegmentStorageAsync, InMemoryTelemetryStorageAsync, InMemoryEventStorageAsync, \ + InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync +from splitio.sync.manager import Manager, ManagerAsync +from splitio.sync.synchronizer import Synchronizer, SynchronizerAsync, SplitSynchronizers, SplitTasks +from splitio.sync.split import SplitSynchronizer, SplitSynchronizerAsync +from splitio.sync.segment import SegmentSynchronizer, SegmentSynchronizerAsync from splitio.storage.adapters.redis import RedisAdapter, RedisPipelineAdapter +from splitio.tasks.util import asynctask +from tests.storage.test_pluggable import StorageMockAdapter, StorageMockAdapterAsync +from tests.integration import splits_json class SplitFactoryTests(object): """Split factory test cases.""" + def test_flag_sets_counts(self): + factory = get_factory("none", config={ + 'flagSetsFilter': ['set1', 'set2', 'set3'] + }) + + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets == 3 + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets_invalid == 0 + event = threading.Event() + factory.destroy(event) + event.wait() + + factory = get_factory("none", config={ + 'flagSetsFilter': ['s#et1', 'set2', 'set3'] + }) + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets == 3 + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets_invalid == 1 + event = threading.Event() + factory.destroy(event) + event.wait() + + factory = get_factory("none", config={ + 'flagSetsFilter': ['s#et1', 22, 'set3'] + }) + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets == 3 + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets_invalid == 2 + event = threading.Event() + factory.destroy(event) + event.wait() + def test_inmemory_client_creation_streaming_false(self, mocker): """Test that a client with in-memory storage is created correctly.""" - # Setup synchronizer - def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, sse_url=None, client_key=None): + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): synchronizer = mocker.Mock(spec=Synchronizer) synchronizer.sync_all.return_values = None self._ready_flag = ready_flag self._synchronizer = synchronizer self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer + mocker.patch('splitio.sync.manager.Manager.__init__', new=_split_synchronizer) # Start factory and make assertions factory = get_factory('some_api_key') + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + assert isinstance(factory._storages['splits'], inmemmory.InMemorySplitStorage) assert isinstance(factory._storages['segments'], inmemmory.InMemorySegmentStorage) assert isinstance(factory._storages['impressions'], inmemmory.InMemoryImpressionStorage) @@ -49,14 +101,16 @@ def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk assert factory._storages['events']._events.maxsize == 10000 assert isinstance(factory._sync_manager, Manager) - assert isinstance(factory._recorder, StandardRecorder) assert isinstance(factory._recorder._impressions_manager, ImpressionsManager) assert isinstance(factory._recorder._event_sotrage, inmemmory.EventStorage) assert isinstance(factory._recorder._impression_storage, inmemmory.ImpressionStorage) assert factory._labels_enabled is True - factory.block_until_ready() + try: + factory.block_until_ready(1) + except: + pass assert factory.ready factory.destroy() @@ -64,7 +118,7 @@ def test_redis_client_creation(self, mocker): """Test that a client with redis storage is created correctly.""" strict_redis_mock = mocker.Mock() mocker.patch('splitio.storage.adapters.redis.StrictRedis', new=strict_redis_mock) - + fallback_treatments_configuration = FallbackTreatmentsConfiguration(FallbackTreatment("on")) config = { 'labelsEnabled': False, 'impressionListener': 123, @@ -72,6 +126,7 @@ def test_redis_client_creation(self, mocker): 'redisPort': 1234, 'redisDb': 1, 'redisPassword': 'some_password', + 'redisUsername': 'redis_user', 'redisSocketTimeout': 123, 'redisSocketConnectTimeout': 123, 'redisSocketKeepalive': 123, @@ -79,7 +134,6 @@ def test_redis_client_creation(self, mocker): 'redisConnectionPool': False, 'redisUnixSocketPath': '/some_path', 'redisEncodingErrors': 'non-strict', - 'redisErrors': True, 'redisDecodeResponses': True, 'redisRetryOnTimeout': True, 'redisSsl': True, @@ -88,24 +142,33 @@ def test_redis_client_creation(self, mocker): 'redisSslCertReqs': 'some_cert_req', 'redisSslCaCerts': 'some_ca_cert', 'redisMaxConnections': 999, + 'flagSetsFilter': ['set_1'], + 'fallbackTreatments': fallback_treatments_configuration } factory = get_factory('some_api_key', config=config) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + assert isinstance(factory._get_storage('splits'), redis.RedisSplitStorage) assert isinstance(factory._get_storage('segments'), redis.RedisSegmentStorage) assert isinstance(factory._get_storage('impressions'), redis.RedisImpressionsStorage) assert isinstance(factory._get_storage('events'), redis.RedisEventsStorage) - assert factory._sync_manager is None + assert factory._get_storage('splits').flag_set_filter.flag_sets == set([]) + assert factory._fallback_treatment_calculator.fallback_treatments_configuration.global_fallback_treatment.treatment == fallback_treatments_configuration.global_fallback_treatment.treatment adapter = factory._get_storage('splits')._redis assert adapter == factory._get_storage('segments')._redis assert adapter == factory._get_storage('impressions')._redis assert adapter == factory._get_storage('events')._redis - assert strict_redis_mock.mock_calls == [mocker.call( + assert strict_redis_mock.mock_calls[0] == mocker.call( host='some_host', port=1234, db=1, + username='redis_user', password='some_password', socket_timeout=123, socket_connect_timeout=123, @@ -115,7 +178,6 @@ def test_redis_client_creation(self, mocker): unix_socket_path='/some_path', encoding='utf-8', encoding_errors='non-strict', - errors=True, decode_responses=True, retry_on_timeout=True, ssl=True, @@ -123,36 +185,20 @@ def test_redis_client_creation(self, mocker): ssl_certfile='some_cert_file', ssl_cert_reqs='some_cert_req', ssl_ca_certs='some_ca_cert', - max_connections=999 - )] + max_connections=999, + ) assert factory._labels_enabled is False assert isinstance(factory._recorder, PipelinedRecorder) assert isinstance(factory._recorder._impressions_manager, ImpressionsManager) assert isinstance(factory._recorder._make_pipe(), RedisPipelineAdapter) assert isinstance(factory._recorder._event_sotrage, redis.RedisEventsStorage) assert isinstance(factory._recorder._impression_storage, redis.RedisImpressionsStorage) - factory.block_until_ready() - assert factory.ready - factory.destroy() - def test_uwsgi_forked_client_creation(self): - """Test client with preforked initialization.""" - factory = get_factory('some_api_key', config={'preforkedInitialization': True}) - assert isinstance(factory._storages['splits'], inmemmory.InMemorySplitStorage) - assert isinstance(factory._storages['segments'], inmemmory.InMemorySegmentStorage) - assert isinstance(factory._storages['impressions'], inmemmory.InMemoryImpressionStorage) - assert factory._storages['impressions']._impressions.maxsize == 10000 - assert isinstance(factory._storages['events'], inmemmory.InMemoryEventStorage) - assert factory._storages['events']._events.maxsize == 10000 - - assert isinstance(factory._sync_manager, Manager) - - assert isinstance(factory._recorder, StandardRecorder) - assert isinstance(factory._recorder._impressions_manager, ImpressionsManager) - assert isinstance(factory._recorder._event_sotrage, inmemmory.EventStorage) - assert isinstance(factory._recorder._impression_storage, inmemmory.ImpressionStorage) - - assert factory._status == Status.WAITING_FORK + try: + factory.block_until_ready(1) + except: + pass + assert factory.ready factory.destroy() def test_destroy(self, mocker): @@ -204,26 +250,44 @@ def _imppression_count_task_init_mock(self, synchronize_counters): mocker.patch('splitio.client.factory.ImpressionsCountSyncTask.__init__', new=_imppression_count_task_init_mock) + telemetry_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + telemetry_async_task_mock.stop.side_effect = stop_mock + + def _telemetry_task_init_mock(self, synchronize_telemetry, synchronize_telemetry2): + self._task = telemetry_async_task_mock + mocker.patch('splitio.client.factory.TelemetrySyncTask.__init__', + new=_telemetry_task_init_mock) + split_sync = mocker.Mock(spec=SplitSynchronizer) - split_sync.synchronize_splits.return_values = None + split_sync.synchronize_splits.return_value = [] segment_sync = mocker.Mock(spec=SegmentSynchronizer) segment_sync.synchronize_segments.return_values = None syncs = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), - mocker.Mock(), mocker.Mock()) + mocker.Mock(), mocker.Mock(), mocker.Mock()) tasks = SplitTasks(split_async_task_mock, segment_async_task_mock, imp_async_task_mock, - evt_async_task_mock, imp_count_async_task_mock) + evt_async_task_mock, imp_count_async_task_mock, telemetry_async_task_mock) # Setup synchronizer - def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, sse_url=None, client_key=None): + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): synchronizer = Synchronizer(syncs, tasks) self._ready_flag = ready_flag self._synchronizer = synchronizer self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer mocker.patch('splitio.sync.manager.Manager.__init__', new=_split_synchronizer) # Start factory and make assertions + # Using invalid key should result in a timeout exception factory = get_factory('some_api_key') - factory.block_until_ready() + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + try: + factory.block_until_ready(1) + except: + pass assert factory.ready assert factory.destroyed is False @@ -287,32 +351,52 @@ def _imppression_count_task_init_mock(self, synchronize_counters): mocker.patch('splitio.client.factory.ImpressionsCountSyncTask.__init__', new=_imppression_count_task_init_mock) + telemetry_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + telemetry_async_task_mock.stop.side_effect = stop_mock + + def _telemetry_task_init_mock(self, synchronize_telemetry, synchronize_telemetry2): + self._task = telemetry_async_task_mock + mocker.patch('splitio.client.factory.TelemetrySyncTask.__init__', + new=_telemetry_task_init_mock) + + internal_event_task_mock = mocker.Mock(spec=EventsTask) + internal_event_task_mock.stop.side_effect = stop_mock_2 + internal_event_task_mock.start.side_effect = stop_mock_2 + split_sync = mocker.Mock(spec=SplitSynchronizer) - split_sync.synchronize_splits.return_values = None + split_sync.synchronize_splits.return_value = [] segment_sync = mocker.Mock(spec=SegmentSynchronizer) segment_sync.synchronize_segments.return_values = None syncs = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), - mocker.Mock(), mocker.Mock()) + mocker.Mock(), mocker.Mock(), mocker.Mock()) tasks = SplitTasks(split_async_task_mock, segment_async_task_mock, imp_async_task_mock, - evt_async_task_mock, imp_count_async_task_mock) + evt_async_task_mock, imp_count_async_task_mock, telemetry_async_task_mock, None, None, internal_event_task_mock) # Setup synchronizer - def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, sse_url=None, client_key=None): + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): synchronizer = Synchronizer(syncs, tasks) self._ready_flag = ready_flag self._synchronizer = synchronizer self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer mocker.patch('splitio.sync.manager.Manager.__init__', new=_split_synchronizer) # Start factory and make assertions factory = get_factory('some_api_key') - factory.block_until_ready() - assert factory.ready + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + try: + factory.block_until_ready(1) + except: + pass + assert factory._status == Status.READY assert factory.destroyed is False event = threading.Event() factory.destroy(event) - assert not event.is_set() time.sleep(1) assert event.is_set() assert len(imp_async_task_mock.stop.mock_calls) == 1 @@ -322,7 +406,7 @@ def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk def test_destroy_with_event_redis(self, mocker): def _make_factory_with_apikey(apikey, *_, **__): - return SplitFactory(apikey, {}, True, mocker.Mock(spec=ImpressionsManager), None) + return SplitFactory(apikey, {}, True, mocker.Mock(spec=ImpressionsManager), None, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) factory_module_logger = mocker.Mock() build_redis = mocker.Mock() @@ -337,6 +421,11 @@ def _make_factory_with_apikey(apikey, *_, **__): } factory = get_factory("none", config=config) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + event = threading.Event() factory.destroy(event) event.wait() @@ -344,6 +433,11 @@ def _make_factory_with_apikey(apikey, *_, **__): assert len(build_redis.mock_calls) == 1 factory = get_factory("none", config=config) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + factory.destroy(None) time.sleep(0.1) assert factory.destroyed @@ -353,10 +447,12 @@ def test_multiple_factories(self, mocker): """Test multiple factories instantiation and tracking.""" sdk_ready_flag = threading.Event() - def _init(self, ready_flag, some, auth_api, streaming_enabled, sse_url=None): + def _init(self, ready_flag, some, auth_api, streaming_enabled, telemetry_runtime_producer, telemetry_init_consumer, sse_url=None): self._ready_flag = ready_flag self._synchronizer = mocker.Mock(spec=Synchronizer) self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer + self._telemetry_init_consumer = telemetry_init_consumer mocker.patch('splitio.sync.manager.Manager.__init__', new=_init) def _start(self, *args, **kwargs): @@ -367,10 +463,10 @@ def _stop(self, *args, **kwargs): pass mocker.patch('splitio.sync.manager.Manager.stop', new=_stop) - mockManager = Manager(sdk_ready_flag, mocker.Mock(), mocker.Mock(), False) + mockManager = Manager(sdk_ready_flag, mocker.Mock(), mocker.Mock(), False, mocker.Mock(), mocker.Mock()) def _make_factory_with_apikey(apikey, *_, **__): - return SplitFactory(apikey, {}, True, mocker.Mock(spec=ImpressionsManager), mockManager) + return SplitFactory(apikey, {}, True, mocker.Mock(spec=StandardRecorder), mocker.Mock(), mocker.Mock(), mockManager, mocker.Mock(), mocker.Mock(), mocker.Mock()) factory_module_logger = mocker.Mock() build_in_memory = mocker.Mock() @@ -387,13 +483,23 @@ def _make_factory_with_apikey(apikey, *_, **__): _INSTANTIATED_FACTORIES.clear() # Clear all factory counters for testing purposes factory1 = get_factory('some_api_key') + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory1._telemetry_submitter = TelemetrySubmitterMock() + assert _INSTANTIATED_FACTORIES['some_api_key'] == 1 assert factory_module_logger.warning.mock_calls == [] factory2 = get_factory('some_api_key') + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory2._telemetry_submitter = TelemetrySubmitterMock() + assert _INSTANTIATED_FACTORIES['some_api_key'] == 2 assert factory_module_logger.warning.mock_calls == [mocker.call( - "factory instantiation: You already have %d %s with this API Key. " + "factory instantiation: You already have %d %s with this SDK Key. " "We recommend keeping only one instance of the factory at all times " "(Singleton pattern) and reusing it throughout your application.", 1, @@ -402,9 +508,14 @@ def _make_factory_with_apikey(apikey, *_, **__): factory_module_logger.reset_mock() factory3 = get_factory('some_api_key') + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory3._telemetry_submitter = TelemetrySubmitterMock() + assert _INSTANTIATED_FACTORIES['some_api_key'] == 3 assert factory_module_logger.warning.mock_calls == [mocker.call( - "factory instantiation: You already have %d %s with this API Key. " + "factory instantiation: You already have %d %s with this SDK Key. " "We recommend keeping only one instance of the factory at all times " "(Singleton pattern) and reusing it throughout your application.", 2, @@ -413,6 +524,11 @@ def _make_factory_with_apikey(apikey, *_, **__): factory_module_logger.reset_mock() factory4 = get_factory('some_other_api_key') + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory4._telemetry_submitter = TelemetrySubmitterMock() + assert _INSTANTIATED_FACTORIES['some_api_key'] == 3 assert _INSTANTIATED_FACTORIES['some_other_api_key'] == 1 assert factory_module_logger.warning.mock_calls == [mocker.call( @@ -427,9 +543,15 @@ def _make_factory_with_apikey(apikey, *_, **__): event.wait() assert _INSTANTIATED_FACTORIES['some_other_api_key'] == 1 assert _INSTANTIATED_FACTORIES['some_api_key'] == 2 - factory2.destroy() - factory3.destroy() - factory4.destroy() + event = threading.Event() + factory2.destroy(event) + event.wait() + event = threading.Event() + factory3.destroy(event) + event.wait() + event = threading.Event() + factory4.destroy(event) + event.wait() def test_uwsgi_preforked(self, mocker): """Test preforked initializations.""" @@ -472,7 +594,15 @@ def _get_storage_mock(self, name): 'preforkedInitialization': True, } factory = get_factory("none", config=config) - factory.block_until_ready(10) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + try: + factory.block_until_ready(10) + except: + pass assert factory._status == Status.WAITING_FORK assert len(sync_all_mock.mock_calls) == 1 assert len(start_mock.mock_calls) == 0 @@ -483,6 +613,9 @@ def _get_storage_mock(self, name): assert clear_impressions._called == 1 assert clear_events._called == 1 + factory.destroy() + time.sleep(0.1) + assert factory.destroyed def test_error_prefork(self, mocker): """Test not handling fork.""" @@ -492,9 +625,490 @@ def test_error_prefork(self, mocker): filename = os.path.join(os.path.dirname(__file__), '../integration/files', 'file2.yaml') factory = get_factory('localhost', config={'splitFile': filename}) - factory.block_until_ready(1) - + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + try: + factory.block_until_ready(1) + except: + pass _logger = mocker.Mock() mocker.patch('splitio.client.factory._LOGGER', new=_logger) factory.resume() assert _logger.warning.mock_calls == expected_msg + factory.destroy() + + def test_pluggable_client_creation(self, mocker): + """Test that a client with pluggable storage is created correctly.""" + config = { + 'labelsEnabled': False, + 'impressionListener': 123, + 'storageType': 'pluggable', + 'storageWrapper': StorageMockAdapter(), + 'flagSetsFilter': ['set_1'] + } + factory = get_factory('some_api_key', config=config) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + assert isinstance(factory._get_storage('splits'), pluggable.PluggableSplitStorage) + assert isinstance(factory._get_storage('segments'), pluggable.PluggableSegmentStorage) + assert isinstance(factory._get_storage('impressions'), pluggable.PluggableImpressionsStorage) + assert isinstance(factory._get_storage('events'), pluggable.PluggableEventsStorage) + assert factory._get_storage('splits').flag_set_filter.flag_sets == set([]) + + adapter = factory._get_storage('splits')._pluggable_adapter + assert adapter == factory._get_storage('segments')._pluggable_adapter + assert adapter == factory._get_storage('impressions')._pluggable_adapter + assert adapter == factory._get_storage('events')._pluggable_adapter + + assert factory._labels_enabled is False + assert isinstance(factory._recorder, StandardRecorder) + assert isinstance(factory._recorder._impressions_manager, ImpressionsManager) + assert isinstance(factory._recorder._event_sotrage, pluggable.PluggableEventsStorage) + assert isinstance(factory._recorder._impression_storage, pluggable.PluggableImpressionsStorage) + + try: + factory.block_until_ready(1) + except: + pass + assert factory.ready + factory.destroy() + time.sleep(0.1) + assert factory.destroyed + + def test_destroy_with_event_pluggable(self, mocker): + config = { + 'labelsEnabled': False, + 'impressionListener': 123, + 'storageType': 'pluggable', + 'storageWrapper': StorageMockAdapter() + } + + factory = get_factory("none", config=config) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + event = threading.Event() + factory.destroy(event) + event.wait() + assert factory.destroyed + + factory = get_factory("none", config=config) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + factory.destroy(None) + time.sleep(0.1) + assert factory.destroyed + + def test_internal_ready_event_notification(self, mocker): + """Test that a client with in-memory storage is sending internal events correctly.""" + # Setup synchronizer + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): + synchronizer = mocker.Mock(spec=Synchronizer) + synchronizer.sync_all.return_values = None + self._ready_flag = ready_flag + self._synchronizer = synchronizer + self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer + + mocker.patch('splitio.sync.manager.Manager.__init__', new=_split_synchronizer) + + # Start factory and make assertions + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + impression_storage = InMemoryImpressionStorage(10, telemetry_runtime_producer) + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) + event_storage = mocker.Mock(spec=EventStorage) + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + recorder = StandardRecorder(impmanager, event_storage, impression_storage, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory("some key", + {'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': impression_storage, + 'events': event_storage}, + mocker.Mock(), + recorder, + events_queue, + mocker.Mock(), + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + try: + factory.block_until_ready(1) + except: + pass + + assert factory.ready + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.SDK_READY + assert event.metadata == None + factory.destroy() + + def test_uwsgi_forked_client_creation(self): + """Test client with preforked initialization.""" + # Invalid API Key with preforked should exit after 3 attempts. + factory = get_factory('some_api_key', config={'preforkedInitialization': True}) + class TelemetrySubmitterMock(): + def synchronize_config(*_): + pass + factory._telemetry_submitter = TelemetrySubmitterMock() + + assert isinstance(factory._storages['splits'], inmemmory.InMemorySplitStorage) + assert isinstance(factory._storages['segments'], inmemmory.InMemorySegmentStorage) + assert isinstance(factory._storages['impressions'], inmemmory.InMemoryImpressionStorage) + assert factory._storages['impressions']._impressions.maxsize == 10000 + assert isinstance(factory._storages['events'], inmemmory.InMemoryEventStorage) + assert factory._storages['events']._events.maxsize == 10000 + + assert isinstance(factory._sync_manager, Manager) + + assert isinstance(factory._recorder, StandardRecorder) + assert isinstance(factory._recorder._impressions_manager, ImpressionsManager) + assert isinstance(factory._recorder._event_sotrage, inmemmory.EventStorage) + assert isinstance(factory._recorder._impression_storage, inmemmory.ImpressionStorage) + + assert factory._status == Status.WAITING_FORK + factory.destroy() + time.sleep(0.1) + assert factory.destroyed + +class SplitFactoryAsyncTests(object): + """Split factory async test cases.""" + + @pytest.mark.asyncio + async def test_flag_sets_counts(self): + fallback_treatments_configuration = FallbackTreatmentsConfiguration(FallbackTreatment("on")) + factory = await get_factory_async("none", config={ + 'flagSetsFilter': ['set1', 'set2', 'set3'], + 'streamEnabled': False, + 'fallbackTreatments': fallback_treatments_configuration + }) + assert factory._fallback_treatment_calculator.fallback_treatments_configuration.global_fallback_treatment.treatment == fallback_treatments_configuration.global_fallback_treatment.treatment + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets == 3 + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets_invalid == 0 + await factory.destroy() + + factory = await get_factory_async("none", config={ + 'flagSetsFilter': ['s#et1', 'set2', 'set3'] + }) + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets == 3 + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets_invalid == 1 + await factory.destroy() + + factory = await get_factory_async("none", config={ + 'flagSetsFilter': ['s#et1', 22, 'set3'] + }) + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets == 3 + assert factory._telemetry_init_producer._telemetry_storage._tel_config._flag_sets_invalid == 2 + await factory.destroy() + + @pytest.mark.asyncio + async def test_inmemory_client_creation_streaming_false_async(self, mocker): + """Test that a client with in-memory storage is created correctly for async.""" + # Setup synchronizer + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): + synchronizer = mocker.Mock(spec=SynchronizerAsync) + async def sync_all(*_): + return None + synchronizer.sync_all = sync_all + + def start_periodic_fetching(): + pass + synchronizer.start_periodic_fetching = start_periodic_fetching + + self._ready_flag = ready_flag + self._synchronizer = synchronizer + self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer + + mocker.patch('splitio.sync.manager.ManagerAsync.__init__', new=_split_synchronizer) + + async def synchronize_config(*_): + pass + mocker.patch('splitio.sync.telemetry.InMemoryTelemetrySubmitterAsync.synchronize_config', new=synchronize_config) + + # Start factory and make assertions + factory2 = await get_factory_async('some_api_key', config={'streamingEmabled': False}) + + assert isinstance(factory2, SplitFactoryAsync) + assert isinstance(factory2._storages['splits'], inmemmory.InMemorySplitStorageAsync) + assert isinstance(factory2._storages['segments'], inmemmory.InMemorySegmentStorageAsync) + assert isinstance(factory2._storages['impressions'], inmemmory.InMemoryImpressionStorageAsync) + assert factory2._storages['impressions']._impressions.maxsize == 10000 + assert isinstance(factory2._storages['events'], inmemmory.InMemoryEventStorageAsync) + assert factory2._storages['events']._events.maxsize == 10000 + + assert isinstance(factory2._sync_manager, ManagerAsync) + + assert isinstance(factory2._recorder, StandardRecorderAsync) + assert isinstance(factory2._recorder._impressions_manager, ImpressionsManager) + assert isinstance(factory2._recorder._event_sotrage, inmemmory.EventStorage) + assert isinstance(factory2._recorder._impression_storage, inmemmory.ImpressionStorage) + + assert factory2._labels_enabled is True + try: + await factory2.block_until_ready(1) + except: + pass + assert factory2._status == Status.READY + await factory2.destroy() + + @pytest.mark.asyncio + async def test_destroy_async(self, mocker): + """Test that tasks are shutdown and data is flushed when destroy is called.""" + + async def stop_mock(): + return + + split_async_task_mock = mocker.Mock(spec=asynctask.AsyncTaskAsync) + split_async_task_mock.stop.side_effect = stop_mock + + def _split_task_init_mock(self, synchronize_splits, period): + self._task = split_async_task_mock + self._period = period + mocker.patch('splitio.client.factory.SplitSynchronizationTaskAsync.__init__', + new=_split_task_init_mock) + + segment_async_task_mock = mocker.Mock(spec=asynctask.AsyncTaskAsync) + segment_async_task_mock.stop.side_effect = stop_mock + + def _segment_task_init_mock(self, synchronize_segments, period): + self._task = segment_async_task_mock + self._period = period + mocker.patch('splitio.client.factory.SegmentSynchronizationTaskAsync.__init__', + new=_segment_task_init_mock) + + imp_async_task_mock = mocker.Mock(spec=asynctask.AsyncTaskAsync) + imp_async_task_mock.stop.side_effect = stop_mock + + def _imppression_task_init_mock(self, synchronize_impressions, period): + self._period = period + self._task = imp_async_task_mock + mocker.patch('splitio.client.factory.ImpressionsSyncTaskAsync.__init__', + new=_imppression_task_init_mock) + + evt_async_task_mock = mocker.Mock(spec=asynctask.AsyncTaskAsync) + evt_async_task_mock.stop.side_effect = stop_mock + + def _event_task_init_mock(self, synchronize_events, period): + self._period = period + self._task = evt_async_task_mock + mocker.patch('splitio.client.factory.EventsSyncTaskAsync.__init__', new=_event_task_init_mock) + + imp_count_async_task_mock = mocker.Mock(spec=asynctask.AsyncTaskAsync) + imp_count_async_task_mock.stop.side_effect = stop_mock + + def _imppression_count_task_init_mock(self, synchronize_counters): + self._task = imp_count_async_task_mock + mocker.patch('splitio.client.factory.ImpressionsCountSyncTaskAsync.__init__', + new=_imppression_count_task_init_mock) + + telemetry_async_task_mock = mocker.Mock(spec=asynctask.AsyncTaskAsync) + telemetry_async_task_mock.stop.side_effect = stop_mock + + def _telemetry_task_init_mock(self, synchronize_telemetry, synchronize_telemetry2): + self._task = telemetry_async_task_mock + mocker.patch('splitio.client.factory.TelemetrySyncTaskAsync.__init__', + new=_telemetry_task_init_mock) + + split_sync = mocker.Mock(spec=SplitSynchronizerAsync) + async def synchronize_splits(*_): + return [] + split_sync.synchronize_splits = synchronize_splits + + segment_sync = mocker.Mock(spec=SegmentSynchronizerAsync) + async def synchronize_segments(*_): + return True + segment_sync.synchronize_segments = synchronize_segments + + syncs = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock(), mocker.Mock()) + tasks = SplitTasks(split_async_task_mock, segment_async_task_mock, imp_async_task_mock, + evt_async_task_mock, imp_count_async_task_mock, telemetry_async_task_mock) + + # Setup synchronizer + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): + synchronizer = SynchronizerAsync(syncs, tasks) + self._ready_flag = ready_flag + self._synchronizer = synchronizer + self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer + mocker.patch('splitio.sync.manager.ManagerAsync.__init__', new=_split_synchronizer) + + async def synchronize_config(*_): + pass + mocker.patch('splitio.sync.telemetry.InMemoryTelemetrySubmitterAsync.synchronize_config', new=synchronize_config) + # Start factory and make assertions + # Using invalid key should result in a timeout exception + factory = await get_factory_async('some_api_key') + self.manager_called = False + async def stop(*_): + self.manager_called = True + pass + factory._sync_manager.stop = stop + + async def start(*_): + pass + factory._sync_manager.start = start + + try: + await factory.block_until_ready(1) + except: + pass + assert factory._status == Status.READY + assert factory.destroyed is False + + await factory.destroy() + assert self.manager_called + assert factory.destroyed is True + + @pytest.mark.asyncio + async def test_pluggable_client_creation_async(self, mocker): + """Test that a client with pluggable storage is created correctly.""" + config = { + 'labelsEnabled': False, + 'impressionListener': 123, + 'featuresRefreshRate': 1, + 'segmentsRefreshRate': 1, + 'metricsRefreshRate': 1, + 'impressionsRefreshRate': 1, + 'eventsPushRate': 1, + 'storageType': 'pluggable', + 'storageWrapper': StorageMockAdapterAsync() + } + factory = await get_factory_async('some_api_key', config=config) + assert isinstance(factory._get_storage('splits'), pluggable.PluggableSplitStorageAsync) + assert isinstance(factory._get_storage('segments'), pluggable.PluggableSegmentStorageAsync) + assert isinstance(factory._get_storage('impressions'), pluggable.PluggableImpressionsStorageAsync) + assert isinstance(factory._get_storage('events'), pluggable.PluggableEventsStorageAsync) + + adapter = factory._get_storage('splits')._pluggable_adapter + assert adapter == factory._get_storage('segments')._pluggable_adapter + assert adapter == factory._get_storage('impressions')._pluggable_adapter + assert adapter == factory._get_storage('events')._pluggable_adapter + + assert factory._labels_enabled is False + assert isinstance(factory._recorder, StandardRecorderAsync) + assert isinstance(factory._recorder._impressions_manager, ImpressionsManager) + assert isinstance(factory._recorder._event_sotrage, pluggable.PluggableEventsStorageAsync) + assert isinstance(factory._recorder._impression_storage, pluggable.PluggableImpressionsStorageAsync) + try: + await factory.block_until_ready(1) + except: + pass + assert factory._status == Status.READY + await factory.destroy() + + @pytest.mark.asyncio + async def test_destroy_redis_async(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + async def _make_factory_with_apikey(apikey, *_, **__): + return SplitFactoryAsync(apikey, {}, True, mocker.Mock(), internal_events_queue, events_manager, mocker.Mock(spec=ManagerAsync), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + + factory_module_logger = mocker.Mock() + build_redis = mocker.Mock() + build_redis.side_effect = _make_factory_with_apikey + mocker.patch('splitio.client.factory._LOGGER', new=factory_module_logger) + mocker.patch('splitio.client.factory._build_redis_factory_async', new=build_redis) + + config = { + 'redisDb': 0, + 'redisHost': 'localhost', + 'redisPosrt': 6379, + } + factory = await get_factory_async("none", config=config) + await factory.destroy() + assert factory.destroyed + assert len(build_redis.mock_calls) == 1 + + factory = await get_factory_async("none", config=config) + await factory.destroy() + await asyncio.sleep(0.5) + assert factory.destroyed + assert len(build_redis.mock_calls) == 2 + + @pytest.mark.asyncio + async def test_internal_ready_event_notification(self, mocker): + """Test that a client with in-memory storage is sending internal events correctly.""" + # Setup synchronizer + def _split_synchronizer(self, ready_flag, some, auth_api, streaming_enabled, sdk_matadata, telemetry_runtime_producer, sse_url=None, client_key=None): + synchronizer = mocker.Mock(spec=SynchronizerAsync) + async def sync_all(*_): + return None + synchronizer.sync_all = sync_all + + def start_periodic_fetching(): + pass + synchronizer.start_periodic_fetching = start_periodic_fetching + + def start_periodic_data_recording(): + pass + synchronizer.start_periodic_data_recording = start_periodic_data_recording + + self._ready_flag = ready_flag + self._synchronizer = synchronizer + self._streaming_enabled = False + self._telemetry_runtime_producer = telemetry_runtime_producer + + mocker.patch('splitio.sync.manager.ManagerAsync.__init__', new=_split_synchronizer) + + async def synchronize_config(*_): + await asyncio.sleep(2) + pass + mocker.patch('splitio.sync.telemetry.InMemoryTelemetrySubmitterAsync.synchronize_config', new=synchronize_config) + + async def record_ready_time(*_): + pass + mocker.patch('splitio.models.telemetry.TelemetryConfigAsync.record_ready_time', new=record_ready_time) + + async def record_active_and_redundant_factories(*_): + pass + mocker.patch('splitio.models.telemetry.TelemetryConfigAsync.record_active_and_redundant_factories', new=record_active_and_redundant_factories) + + # Start factory and make assertions + factory = await get_factory_async('some_api_key', config={'streamingEmabled': False}) + for task in asyncio.all_tasks(): + if task._coro.__qualname__ == "EventsTaskAsync._run": + task.cancel() + try: + await factory.block_until_ready(3) + except: + pass + await asyncio.sleep(.2) + event = await factory._internal_events_queue.get() + assert event.internal_event == SdkInternalEvent.SDK_READY + assert event.metadata == None + await factory.destroy() \ No newline at end of file diff --git a/tests/client/test_input_validator.py b/tests/client/test_input_validator.py index 98416fe6..e1634f54 100644 --- a/tests/client/test_input_validator.py +++ b/tests/client/test_input_validator.py @@ -1,15 +1,23 @@ """Unit tests for the input_validator module.""" +import pytest import logging +import asyncio -from splitio.client.factory import SplitFactory, get_factory -from splitio.client.client import CONTROL, Client, _LOGGER as _logger -from splitio.client.manager import SplitManager +from splitio.client.factory import SplitFactory, get_factory, SplitFactoryAsync, get_factory_async +from splitio.client.client import CONTROL, Client, _LOGGER as _logger, ClientAsync from splitio.client.key import Key -from splitio.storage import SplitStorage, EventStorage, ImpressionStorage, SegmentStorage +from splitio.events.events_manager import EventsManagerAsync +from splitio.storage import SplitStorage, EventStorage, ImpressionStorage, SegmentStorage, RuleBasedSegmentsStorage +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync, \ + InMemorySplitStorage, InMemorySplitStorageAsync, InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync from splitio.models.splits import Split +from splitio.models.fallback_config import FallbackTreatmentCalculator from splitio.client import input_validator -from splitio.recorder.recorder import StandardRecorder - +from splitio.client.manager import SplitManager, SplitManagerAsync +from splitio.recorder.recorder import StandardRecorder, StandardRecorderAsync +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.engine.impressions.impressions import Manager as ImpressionManager +from splitio.models.fallback_treatment import FallbackTreatment class ClientInputValidationTests(object): """Input validation test cases.""" @@ -23,30 +31,43 @@ def test_get_treatment(self, mocker): conditions_mock = mocker.PropertyMock() conditions_mock.return_value = [] type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] storage_mock = mocker.Mock(spec=SplitStorage) - storage_mock.get.return_value = split_mock - - def _get_storage_mock(storage): - return { + storage_mock.fetch_many.return_value = {'some_feature': split_mock} + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), ImpressionStorage, telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { 'splits': storage_mock, 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, 'impressions': mocker.Mock(spec=ImpressionStorage), 'events': mocker.Mock(spec=EventStorage), - }[storage] - factory_mock = mocker.Mock(spec=SplitFactory) - factory_mock._get_storage.side_effect = _get_storage_mock - factory_destroyed = mocker.PropertyMock() - factory_mock._waiting_fork.return_value = False - factory_destroyed.return_value = False - type(factory_mock).destroyed = factory_destroyed - - client = Client(factory_mock, mocker.Mock()) + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + client = Client(factory, mocker.Mock(), mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) _logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) assert client.get_treatment(None, 'some_feature') == CONTROL assert _logger.error.mock_calls == [ - mocker.call('%s: you passed a null key, key must be a non-empty string.', 'get_treatment') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') ] _logger.reset_mock() @@ -95,31 +116,31 @@ def _get_storage_mock(storage): _logger.reset_mock() assert client.get_treatment('some_key', None) == CONTROL assert _logger.error.mock_calls == [ - mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment('some_key', 123) == CONTROL assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment('some_key', True) == CONTROL assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment('some_key', []) == CONTROL assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment('some_key', '') == CONTROL assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() @@ -223,20 +244,22 @@ def _get_storage_mock(storage): _logger.reset_mock() assert client.get_treatment('matching_key', ' some_feature ', None) == 'default_treatment' assert _logger.warning.mock_calls == [ - mocker.call('%s: feature_name \'%s\' has extra whitespace, trimming.', 'get_treatment', ' some_feature ') + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatment', 'feature flag name', ' some_feature ') ] _logger.reset_mock() - storage_mock.get.return_value = None + storage_mock.fetch_many.return_value = {'some_feature': None} + mocker.patch('splitio.client.client._LOGGER', new=_logger) assert client.get_treatment('matching_key', 'some_feature', None) == CONTROL assert _logger.warning.mock_calls == [ mocker.call( "%s: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", + "please double check what Feature flags exist in the Split user interface.", 'get_treatment', 'some_feature' ) ] + factory.destroy def test_get_treatment_with_config(self, mocker): """Test get_treatment validation.""" @@ -247,34 +270,47 @@ def test_get_treatment_with_config(self, mocker): conditions_mock = mocker.PropertyMock() conditions_mock.return_value = [] type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] def _configs(treatment): return '{"some": "property"}' if treatment == 'default_treatment' else None split_mock.get_configurations_for.side_effect = _configs storage_mock = mocker.Mock(spec=SplitStorage) - storage_mock.get.return_value = split_mock - - def _get_storage_mock(storage): - return { + storage_mock.fetch_many.return_value = {'some_feature': split_mock} + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), ImpressionStorage, telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { 'splits': storage_mock, 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, 'impressions': mocker.Mock(spec=ImpressionStorage), 'events': mocker.Mock(spec=EventStorage), - }[storage] - factory_mock = mocker.Mock(spec=SplitFactory) - factory_mock._get_storage.side_effect = _get_storage_mock - factory_destroyed = mocker.PropertyMock() - factory_destroyed.return_value = False - factory_mock._waiting_fork.return_value = False - type(factory_mock).destroyed = factory_destroyed - - client = Client(factory_mock, mocker.Mock()) + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + client = Client(factory, mocker.Mock(), mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) _logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) assert client.get_treatment_with_config(None, 'some_feature') == (CONTROL, None) assert _logger.error.mock_calls == [ - mocker.call('%s: you passed a null key, key must be a non-empty string.', 'get_treatment_with_config') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') ] _logger.reset_mock() @@ -323,31 +359,31 @@ def _get_storage_mock(storage): _logger.reset_mock() assert client.get_treatment_with_config('some_key', None) == (CONTROL, None) assert _logger.error.mock_calls == [ - mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_name', 'feature_name') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment_with_config('some_key', 123) == (CONTROL, None) assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment_with_config('some_key', True) == (CONTROL, None) assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment_with_config('some_key', []) == (CONTROL, None) assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() assert client.get_treatment_with_config('some_key', '') == (CONTROL, None) assert _logger.error.mock_calls == [ - mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_name', 'feature_name') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') ] _logger.reset_mock() @@ -451,34 +487,36 @@ def _get_storage_mock(storage): _logger.reset_mock() assert client.get_treatment_with_config('matching_key', ' some_feature ', None) == ('default_treatment', '{"some": "property"}') assert _logger.warning.mock_calls == [ - mocker.call('%s: feature_name \'%s\' has extra whitespace, trimming.', 'get_treatment_with_config', ' some_feature ') + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatment_with_config', 'feature flag name', ' some_feature ') ] _logger.reset_mock() - storage_mock.get.return_value = None + storage_mock.fetch_many.return_value = {'some_feature': None} + mocker.patch('splitio.client.client._LOGGER', new=_logger) assert client.get_treatment_with_config('matching_key', 'some_feature', None) == (CONTROL, None) assert _logger.warning.mock_calls == [ mocker.call( "%s: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", + "please double check what Feature flags exist in the Split user interface.", 'get_treatment_with_config', 'some_feature' ) ] + factory.destroy def test_valid_properties(self, mocker): """Test valid_properties() method.""" - assert input_validator.valid_properties(None) == (True, None, 1024) - assert input_validator.valid_properties([]) == (False, None, 0) - assert input_validator.valid_properties(True) == (False, None, 0) - assert input_validator.valid_properties(dict()) == (True, None, 1024) - assert input_validator.valid_properties({2: 123}) == (True, None, 1024) + assert input_validator.valid_properties(None, '') == (True, None, 1024) + assert input_validator.valid_properties([], '') == (False, None, 0) + assert input_validator.valid_properties(True, '') == (False, None, 0) + assert input_validator.valid_properties(dict(), '') == (True, None, 1024) + assert input_validator.valid_properties({2: 123}, '') == (True, None, 1024) class Test: pass assert input_validator.valid_properties({ "test": Test() - }) == (True, {"test": None}, 1028) + }, '') == (True, {"test": None}, 1028) props1 = { "test1": "test", @@ -488,7 +526,7 @@ class Test: "test5": [], 2: "t", } - r1, r2, r3 = input_validator.valid_properties(props1) + r1, r2, r3 = input_validator.valid_properties(props1, '') assert r1 is True assert len(r2.keys()) == 5 assert r2["test1"] == "test" @@ -501,12 +539,12 @@ class Test: props2 = dict() for i in range(301): props2[str(i)] = i - assert input_validator.valid_properties(props2) == (True, props2, 1817) + assert input_validator.valid_properties(props2, '') == (True, props2, 1817) props3 = dict() for i in range(100, 210): props3["prop" + str(i)] = "a" * 300 - r1, r2, r3 = input_validator.valid_properties(props3) + r1, r2, r3 = input_validator.valid_properties(props3, '') assert r1 is False assert r3 == 32952 @@ -514,17 +552,37 @@ def test_track(self, mocker): """Test track method().""" events_storage_mock = mocker.Mock(spec=EventStorage) events_storage_mock.put.return_value = True - factory_mock = mocker.Mock(spec=SplitFactory) - factory_destroyed = mocker.PropertyMock() - factory_destroyed.return_value = False - factory_mock._waiting_fork.return_value = False - type(factory_mock).destroyed = factory_destroyed - factory_mock._apikey = 'some-test' - event_storage = mocker.Mock(spec=EventStorage) event_storage.put.return_value = True - recorder = StandardRecorder(mocker.Mock(), event_storage, mocker.Mock()) - client = Client(factory_mock, recorder) + split_storage_mock = mocker.Mock(spec=SplitStorage) + split_storage_mock.is_valid_traffic_type.return_value = True + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, events_storage_mock, ImpressionStorage, telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': split_storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': mocker.Mock(spec=RuleBasedSegmentsStorage), + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': events_storage_mock, + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + factory._sdk_key = 'some-test' + + client = Client(factory, recorder, mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) client._event_storage = event_storage _logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) @@ -598,9 +656,10 @@ def test_track(self, mocker): _logger.reset_mock() assert client.track("some_key", "TRAFFIC_type", "event_type", 1) is True assert _logger.warning.mock_calls == [ - mocker.call("track: %s should be all lowercase - converting string to lowercase.", 'TRAFFIC_type') + mocker.call("%s: %s '%s' should be all lowercase - converting string to lowercase", 'track', 'traffic type', 'TRAFFIC_type') ] + _logger.reset_mock() assert client.track("some_key", "traffic_type", None, 1) is False assert _logger.error.mock_calls == [ mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') @@ -633,12 +692,12 @@ def test_track(self, mocker): _logger.reset_mock() assert client.track("some_key", "traffic_type", "@@", 1) is False assert _logger.error.mock_calls == [ - mocker.call("%s: you passed %s, event_type must adhere to the regular " + mocker.call("%s: you passed %s, %s must adhere to the regular " "expression %s. This means " - "an event name must be alphanumeric, cannot be more than 80 " + "%s must be alphanumeric, cannot be more than %s " "characters long, and can only include a dash, underscore, " "period, or colon as separators of alphanumeric characters.", - 'track', '@@', '^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$') + 'track', '@@', 'an event name', '^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$', 'an event name', 80) ] _logger.reset_mock() @@ -674,11 +733,9 @@ def test_track(self, mocker): # Test traffic type existance ready_property = mocker.PropertyMock() ready_property.return_value = True - type(factory_mock).ready = ready_property + type(factory).ready = ready_property - split_storage_mock = mocker.Mock(spec=SplitStorage) - split_storage_mock.is_valid_traffic_type.return_value = True - factory_mock._get_storage.return_value = split_storage_mock +# factory._get_storage.return_value = split_storage_mock # Test that it doesn't warn if tt is cached, not in localhost mode and sdk is ready _logger.reset_mock() @@ -692,23 +749,23 @@ def test_track(self, mocker): assert client.track("some_key", "traffic_type", "event_type", None) is True assert _logger.error.mock_calls == [] assert _logger.warning.mock_calls == [mocker.call( - 'track: Traffic Type %s does not have any corresponding Splits in this environment, ' + 'track: Traffic Type %s does not have any corresponding Feature flags in this environment, ' 'make sure you\'re tracking your events to a valid traffic type defined ' - 'in the Split console.', + 'in the Split user interface.', 'traffic_type' )] # Test that it does not warn when in localhost mode. - factory_mock._apikey = 'localhost' + factory._sdk_key = 'localhost' _logger.reset_mock() assert client.track("some_key", "traffic_type", "event_type", None) is True assert _logger.error.mock_calls == [] assert _logger.warning.mock_calls == [] # Test that it does not warn when not in localhost mode and not ready - factory_mock._apikey = 'not-localhost' + factory._sdk_key = 'not-localhost' ready_property.return_value = False - type(factory_mock).ready = ready_property + type(factory).ready = ready_property _logger.reset_mock() assert client.track("some_key", "traffic_type", "event_type", None) is True assert _logger.error.mock_calls == [] @@ -718,14 +775,14 @@ def test_track(self, mocker): _logger.reset_mock() assert client.track("some_key", "traffic_type", "event_type", 1, []) is False assert _logger.error.mock_calls == [ - mocker.call("track: properties must be of type dictionary.") + mocker.call("%s: properties must be of type dictionary.", "track") ] # Test track with invalid properties _logger.reset_mock() assert client.track("some_key", "traffic_type", "event_type", 1, True) is False assert _logger.error.mock_calls == [ - mocker.call("track: properties must be of type dictionary.") + mocker.call("%s: properties must be of type dictionary.", "track") ] # Test track with properties @@ -740,7 +797,7 @@ def test_track(self, mocker): _logger.reset_mock() assert client.track("some_key", "traffic_type", "event_type", 1, props1) is True assert _logger.warning.mock_calls == [ - mocker.call("Property %s is of invalid type. Setting value to None", []) + mocker.call("%s: Property %s is of invalid type. Setting value to None", "track", []) ] # Test track with more than 300 properties @@ -750,7 +807,7 @@ def test_track(self, mocker): _logger.reset_mock() assert client.track("some_key", "traffic_type", "event_type", 1, props2) is True assert _logger.warning.mock_calls == [ - mocker.call("Event has more than 300 properties. Some of them will be trimmed when processed") + mocker.call("%s: Event has more than 300 properties. Some of them will be trimmed when processed", "track") ] # Test track with properties higher than 32kb @@ -760,8 +817,9 @@ def test_track(self, mocker): props3["prop" + str(i)] = "a" * 300 assert client.track("some_key", "traffic_type", "event_type", 1, props3) is False assert _logger.error.mock_calls == [ - mocker.call("The maximum size allowed for the properties is 32768 bytes. Current one is 32952 bytes. Event not queued") + mocker.call("%s: The maximum size allowed for the properties is 32768 bytes. Current one is 32952 bytes. Event not queued", "track") ] + factory.destroy def test_get_treatments(self, mocker): """Test getTreatments() method.""" @@ -772,34 +830,49 @@ def test_get_treatments(self, mocker): conditions_mock = mocker.PropertyMock() conditions_mock.return_value = [] type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] storage_mock = mocker.Mock(spec=SplitStorage) storage_mock.fetch_many.return_value = { - 'some_feature': split_mock, - 'some': split_mock, + 'some_feature': split_mock } - - def _get_storage_mock(storage): - return { + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { 'splits': storage_mock, 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, 'impressions': mocker.Mock(spec=ImpressionStorage), 'events': mocker.Mock(spec=EventStorage), - }[storage] - factory_mock = mocker.Mock(spec=SplitFactory) - factory_mock._get_storage.side_effect = _get_storage_mock - factory_destroyed = mocker.PropertyMock() - factory_destroyed.return_value = False - factory_mock._waiting_fork.return_value = False - type(factory_mock).destroyed = factory_destroyed - - client = Client(factory_mock, mocker.Mock()) + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = Client(factory, recorder, mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) _logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) assert client.get_treatments(None, ['some_feature']) == {'some_feature': CONTROL} assert _logger.error.mock_calls == [ - mocker.call('%s: you passed a null key, key must be a non-empty string.', 'get_treatments') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') ] _logger.reset_mock() @@ -815,6 +888,7 @@ def _get_storage_mock(storage): mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments', 'key', 250) ] + split_mock.name = 'some_feature' _logger.reset_mock() assert client.get_treatments(12345, ['some_feature']) == {'some_feature': 'default_treatment'} assert _logger.warning.mock_calls == [ @@ -836,64 +910,66 @@ def _get_storage_mock(storage): _logger.reset_mock() assert client.get_treatments('some_key', None) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') ] _logger.reset_mock() assert client.get_treatments('some_key', True) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') ] _logger.reset_mock() assert client.get_treatments('some_key', 'some_string') == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') ] _logger.reset_mock() assert client.get_treatments('some_key', []) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') ] _logger.reset_mock() assert client.get_treatments('some_key', [None, None]) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') ] _logger.reset_mock() assert client.get_treatments('some_key', [True]) == {} - assert mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') in _logger.error.mock_calls + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') in _logger.error.mock_calls _logger.reset_mock() assert client.get_treatments('some_key', ['', '']) == {} - assert mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments') in _logger.error.mock_calls + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') in _logger.error.mock_calls _logger.reset_mock() - assert client.get_treatments('some_key', ['some ']) == {'some': 'default_treatment'} + assert client.get_treatments('some_key', ['some_feature ']) == {'some_feature': 'default_treatment'} assert _logger.warning.mock_calls == [ - mocker.call('%s: feature_name \'%s\' has extra whitespace, trimming.', 'get_treatments', 'some ') + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments', 'feature flag name', 'some_feature ') ] _logger.reset_mock() storage_mock.fetch_many.return_value = { 'some_feature': None } - storage_mock.get.return_value = None + storage_mock.fetch_many.return_value = {'some_feature': None} ready_mock = mocker.PropertyMock() ready_mock.return_value = True - type(factory_mock).ready = ready_mock + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) assert client.get_treatments('matching_key', ['some_feature'], None) == {'some_feature': CONTROL} assert _logger.warning.mock_calls == [ mocker.call( "%s: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", + "please double check what Feature flags exist in the Split user interface.", 'get_treatments', 'some_feature' ) ] + factory.destroy def test_get_treatments_with_config(self, mocker): """Test getTreatments() method.""" @@ -904,29 +980,51 @@ def test_get_treatments_with_config(self, mocker): conditions_mock = mocker.PropertyMock() conditions_mock.return_value = [] type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] storage_mock = mocker.Mock(spec=SplitStorage) storage_mock.fetch_many.return_value = { 'some_feature': split_mock } + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + split_mock.name = 'some_feature' - factory_mock = mocker.Mock(spec=SplitFactory) - factory_mock._get_storage.return_value = storage_mock - factory_destroyed = mocker.PropertyMock() - factory_destroyed.return_value = False - factory_mock._waiting_fork.return_value = False - type(factory_mock).destroyed = factory_destroyed def _configs(treatment): return '{"some": "property"}' if treatment == 'default_treatment' else None split_mock.get_configurations_for.side_effect = _configs - client = Client(factory_mock, mocker.Mock()) + client = Client(factory, mocker.Mock(), mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) _logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) assert client.get_treatments_with_config(None, ['some_feature']) == {'some_feature': (CONTROL, None)} assert _logger.error.mock_calls == [ - mocker.call('%s: you passed a null key, key must be a non-empty string.', 'get_treatments_with_config') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_with_config', 'key', 'key') ] _logger.reset_mock() @@ -963,45 +1061,45 @@ def _configs(treatment): _logger.reset_mock() assert client.get_treatments_with_config('some_key', None) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') ] _logger.reset_mock() assert client.get_treatments_with_config('some_key', True) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') ] _logger.reset_mock() assert client.get_treatments_with_config('some_key', 'some_string') == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') ] _logger.reset_mock() assert client.get_treatments_with_config('some_key', []) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') ] _logger.reset_mock() assert client.get_treatments_with_config('some_key', [None, None]) == {} assert _logger.error.mock_calls == [ - mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') ] _logger.reset_mock() assert client.get_treatments_with_config('some_key', [True]) == {} - assert mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') in _logger.error.mock_calls + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') in _logger.error.mock_calls _logger.reset_mock() assert client.get_treatments_with_config('some_key', ['', '']) == {} - assert mocker.call('%s: feature_names must be a non-empty array.', 'get_treatments_with_config') in _logger.error.mock_calls + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') in _logger.error.mock_calls _logger.reset_mock() assert client.get_treatments_with_config('some_key', ['some_feature ']) == {'some_feature': ('default_treatment', '{"some": "property"}')} assert _logger.warning.mock_calls == [ - mocker.call('%s: feature_name \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config', 'some_feature ') + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config', 'feature flag name', 'some_feature ') ] _logger.reset_mock() @@ -1011,104 +1109,2752 @@ def _configs(treatment): storage_mock.get.return_value = None ready_mock = mocker.PropertyMock() ready_mock.return_value = True - type(factory_mock).ready = ready_mock + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) assert client.get_treatments('matching_key', ['some_feature'], None) == {'some_feature': CONTROL} assert _logger.warning.mock_calls == [ mocker.call( "%s: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", + "please double check what Feature flags exist in the Split user interface.", 'get_treatments', 'some_feature' ) ] + factory.destroy + + def test_get_treatments_by_flag_set(self, mocker): + """Test getTreatments() method.""" + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=InMemorySplitStorage) + storage_mock.fetch_many.return_value = { + 'some_feature': split_mock + } + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + storage_mock.get_feature_flags_by_sets.return_value = ['some_feature'] + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + client = Client(factory, recorder, mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + assert client.get_treatments_by_flag_set(None, 'some_set') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] -class ManagerInputValidationTests(object): #pylint: disable=too-few-public-methods - """Manager input validation test cases.""" + _logger.reset_mock() + assert client.get_treatments_by_flag_set("", 'some_set') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] - def test_split_(self, mocker): - """Test split input validation.""" - storage_mock = mocker.Mock(spec=SplitStorage) + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert client.get_treatments_by_flag_set(key, 'some_set') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_by_flag_set', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert client.get_treatments_by_flag_set(12345, 'some_set') == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_by_flag_set', 'key', 12345) + ] + + _logger.reset_mock() + assert client.get_treatments_by_flag_set(True, 'some_set') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + assert client.get_treatments_by_flag_set([], 'some_set') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + client.get_treatments_by_flag_set('some_key', None) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_by_flag_set', 'flag set', 'flag set') + ] + + _logger.reset_mock() + client.get_treatments_by_flag_set('some_key', '$$') + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_by_flag_set', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert client.get_treatments_by_flag_set('some_key', 'some_set ') == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_by_flag_set', 'flag set', 'some_set ') + ] + + _logger.reset_mock() + storage_mock.get_feature_flags_by_sets.return_value = [] + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert client.get_treatments_by_flag_set('matching_key', 'some_set') == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_by_flag_set") + ] + factory.destroy + + def test_get_treatments_by_flag_sets(self, mocker): + """Test getTreatments() method.""" split_mock = mocker.Mock(spec=Split) - storage_mock.get.return_value = split_mock - factory_mock = mocker.Mock(spec=SplitFactory) - factory_mock._get_storage.return_value = storage_mock - factory_destroyed = mocker.PropertyMock() - factory_destroyed.return_value = False - factory_mock._waiting_fork.return_value = False - type(factory_mock).destroyed = factory_destroyed - - manager = SplitManager(factory_mock) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=InMemorySplitStorage) + storage_mock.fetch_many.return_value = { + 'some_feature': split_mock + } + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + storage_mock.get_feature_flags_by_sets.return_value = ['some_feature'] + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = Client(factory, recorder, mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) _logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) - assert manager.split(None) is None + assert client.get_treatments_by_flag_sets(None, ['some_set']) == {'some_feature': CONTROL} assert _logger.error.mock_calls == [ - mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') ] _logger.reset_mock() - assert manager.split("") is None + assert client.get_treatments_by_flag_sets("", ['some_set']) == {'some_feature': CONTROL} assert _logger.error.mock_calls == [ - mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') ] + key = ''.join('a' for _ in range(0, 255)) _logger.reset_mock() - assert manager.split(True) is None + assert client.get_treatments_by_flag_sets(key, ['some_set']) == {'some_feature': CONTROL} assert _logger.error.mock_calls == [ - mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_by_flag_sets', 'key', 250) ] + split_mock.name = 'some_feature' _logger.reset_mock() - assert manager.split([]) is None + assert client.get_treatments_by_flag_sets(12345, ['some_set']) == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_by_flag_sets', 'key', 12345) + ] + + _logger.reset_mock() + assert client.get_treatments_by_flag_sets(True, ['some_set']) == {'some_feature': CONTROL} assert _logger.error.mock_calls == [ - mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') ] _logger.reset_mock() - manager.split('some_split') - assert split_mock.to_split_view.mock_calls == [mocker.call()] - assert _logger.error.mock_calls == [] + assert client.get_treatments_by_flag_sets([], ['some_set']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') + ] _logger.reset_mock() - split_mock.reset_mock() - storage_mock.get.return_value = None - manager.split('nonexistant-split') - assert split_mock.to_split_view.mock_calls == [] - assert _logger.warning.mock_calls == [mocker.call( - "split: you passed \"%s\" that does not exist in this environment, " - "please double check what Splits exist in the web console.", - 'nonexistant-split' - )] + client.get_treatments_by_flag_sets('some_key', None) + assert _logger.warning.mock_calls == [ + mocker.call("%s: flag sets parameter type should be list object, parameter is discarded", "get_treatments_by_flag_sets") + ] -class FactoryInputValidationTests(object): #pylint: disable=too-few-public-methods - """Factory instantiation input validation test cases.""" + _logger.reset_mock() + client.get_treatments_by_flag_sets('some_key', [None]) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_by_flag_sets', 'flag set', 'flag set') + ] - def test_input_validation_factory(self, mocker): - """Test the input validators for factory instantiation.""" - logger = mocker.Mock(spec=logging.Logger) - mocker.patch('splitio.client.input_validator._LOGGER', new=logger) + _logger.reset_mock() + client.get_treatments_by_flag_sets('some_key', ['$$']) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_by_flag_sets', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] - assert get_factory(None) is None - assert logger.error.mock_calls == [ - mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'factory_instantiation', 'apikey', 'apikey') + _logger.reset_mock() + assert client.get_treatments_by_flag_sets('some_key', ['some_set ']) == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_by_flag_sets', 'flag set', 'some_set ') ] - logger.reset_mock() - assert get_factory('') is None - assert logger.error.mock_calls == [ - mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'factory_instantiation', 'apikey', 'apikey') + _logger.reset_mock() + storage_mock.get_feature_flags_by_sets.return_value = [] + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert client.get_treatments_by_flag_sets('matching_key', ['some_set']) == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_by_flag_sets") ] + factory.destroy - logger.reset_mock() - assert get_factory(True) is None - assert logger.error.mock_calls == [ - mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'factory_instantiation', 'apikey', 'apikey') + def test_get_treatments_with_config_by_flag_set(self, mocker): + split_mock = mocker.Mock(spec=Split) + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + split_mock.name = 'some_feature' + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=InMemorySplitStorage) + storage_mock.fetch_many.return_value = { + 'some_feature': split_mock + } + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + + storage_mock.get_feature_flags_by_sets.return_value = ['some_feature'] + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = Client(factory, recorder, mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert client.get_treatments_with_config_by_flag_set(None, 'some_set') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') ] - logger.reset_mock() - f = get_factory(True, config={'redisHost': 'some-host'}) - assert f is not None - assert logger.error.mock_calls == [] - f.destroy() + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_set("", 'some_set') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_set(key, 'some_set') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_with_config_by_flag_set', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_set(12345, 'some_set') == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_with_config_by_flag_set', 'key', 12345) + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_set(True, 'some_set') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_set([], 'some_set') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + client.get_treatments_with_config_by_flag_set('some_key', None) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_with_config_by_flag_set', 'flag set', 'flag set') + ] + + _logger.reset_mock() + client.get_treatments_with_config_by_flag_set('some_key', '$$') + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_with_config_by_flag_set', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_set('some_key', 'some_set ') == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config_by_flag_set', 'flag set', 'some_set ') + ] + + _logger.reset_mock() + storage_mock.get_feature_flags_by_sets.return_value = [] + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert client.get_treatments_with_config_by_flag_set('matching_key', 'some_set') == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_with_config_by_flag_set") + ] + factory.destroy + + def test_get_treatments_with_config_by_flag_sets(self, mocker): + split_mock = mocker.Mock(spec=Split) + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + split_mock.name = 'some_feature' + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=InMemorySplitStorage) + storage_mock.fetch_many.return_value = { + 'some_feature': split_mock + } + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + rbs_storage.fetch_many.return_value = {} + + storage_mock.get_feature_flags_by_sets.return_value = ['some_feature'] + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = Client(factory, recorder, mocker.Mock(), mocker.Mock(), FallbackTreatmentCalculator(None)) + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert client.get_treatments_with_config_by_flag_sets(None, ['some_set']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_sets("", ['some_set']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_sets(key, ['some_set']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_with_config_by_flag_sets', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_sets(12345, ['some_set']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_with_config_by_flag_sets', 'key', 12345) + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_sets(True, ['some_set']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_sets([], ['some_set']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + client.get_treatments_with_config_by_flag_sets('some_key', [None]) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_with_config_by_flag_sets', 'flag set', 'flag set') + ] + + _logger.reset_mock() + client.get_treatments_with_config_by_flag_sets('some_key', ['$$']) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_with_config_by_flag_sets', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert client.get_treatments_with_config_by_flag_sets('some_key', ['some_set ']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config_by_flag_sets', 'flag set', 'some_set ') + ] + + _logger.reset_mock() + storage_mock.get_feature_flags_by_sets.return_value = [] + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert client.get_treatments_with_config_by_flag_sets('matching_key', ['some_set']) == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_with_config_by_flag_sets") + ] + factory.destroy + + def test_flag_sets_validation(self): + """Test sanitization for flag sets.""" + flag_sets = input_validator.validate_flag_sets([' set1', 'set2 ', 'set3'], 'method') + assert sorted(flag_sets) == ['set1', 'set2', 'set3'] + + flag_sets = input_validator.validate_flag_sets(['1set', '_set2'], 'method') + assert flag_sets == ['1set'] + + flag_sets = input_validator.validate_flag_sets(['Set1', 'SET2'], 'method') + assert sorted(flag_sets) == ['set1', 'set2'] + + flag_sets = input_validator.validate_flag_sets(['se\t1', 's/et2', 's*et3', 's!et4', 'se@t5', 'se#t5', 'se$t5', 'se^t5', 'se%t5', 'se&t5'], 'method') + assert flag_sets == [] + + flag_sets = input_validator.validate_flag_sets(['set4', 'set1', 'set3', 'set1'], 'method') + assert sorted(flag_sets) == ['set1', 'set3', 'set4'] + + flag_sets = input_validator.validate_flag_sets(['w' * 50, 's' * 51], 'method') + assert flag_sets == ['w' * 50] + + flag_sets = input_validator.validate_flag_sets('set1', 'method') + assert flag_sets == [] + + flag_sets = input_validator.validate_flag_sets([12, 33], 'method') + assert flag_sets == [] + + def test_fallback_treatments(self, mocker): + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert input_validator.validate_fallback_treatment(FallbackTreatment("on", {"prop":"val"})) + assert input_validator.validate_fallback_treatment(FallbackTreatment("on")) + + _logger.reset_mock() + assert not input_validator.validate_fallback_treatment(FallbackTreatment("on" * 100)) + assert _logger.warning.mock_calls == [ + mocker.call("Config: Fallback treatment size should not exceed %s characters", 100) + ] + + assert input_validator.validate_fallback_treatment(FallbackTreatment("on", {"prop" * 500:"val" * 500})) + + _logger.reset_mock() + assert not input_validator.validate_fallback_treatment(FallbackTreatment("on/c")) + assert _logger.warning.mock_calls == [ + mocker.call("Config: Fallback treatment should match regex %s", "^[0-9]+[.a-zA-Z0-9_-]*$|^[a-zA-Z]+[a-zA-Z0-9_-]*$") + ] + + _logger.reset_mock() + assert not input_validator.validate_fallback_treatment(FallbackTreatment("on$as")) + assert _logger.warning.mock_calls == [ + mocker.call("Config: Fallback treatment should match regex %s", "^[0-9]+[.a-zA-Z0-9_-]*$|^[a-zA-Z]+[a-zA-Z0-9_-]*$") + ] + + assert input_validator.validate_fallback_treatment(FallbackTreatment("on_c")) + assert input_validator.validate_fallback_treatment(FallbackTreatment("on_45-c")) + +class ClientInputValidationAsyncTests(object): + """Input validation test cases.""" + + @pytest.mark.asyncio + async def test_get_treatment(self, mocker): + """Test get_treatment validation.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=SplitStorage) + async def fetch_many(*_): + return { + 'some_feature': split_mock + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), ImpressionStorage, telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, mocker.Mock(), events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatment(None, 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment('', 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') + ] + + _logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert await client.get_treatment(key, 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatment', 'key', 250) + ] + + _logger.reset_mock() + assert await client.get_treatment(12345, 'some_feature') == 'default_treatment' + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatment(float('nan'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment(float('inf'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment(True, 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment([], 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment('some_key', None) == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment('some_key', 123) == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment('some_key', True) == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment('some_key', []) == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment('some_key', '') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment('some_key', 'some_feature') == 'default_treatment' + assert _logger.error.mock_calls == [] + assert _logger.warning.mock_calls == [] + + _logger.reset_mock() + assert await client.get_treatment(Key(None, 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key('', 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key(float('nan'), 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key(float('inf'), 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key(True, 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key([], 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key(12345, 'bucketing_key'), 'some_feature') == 'default_treatment' + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment', 'matching_key', 12345) + ] + + _logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert await client.get_treatment(Key(key, 'bucketing_key'), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatment', 'matching_key', 250) + ] + + _logger.reset_mock() + assert await client.get_treatment(Key('matching_key', None), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key('matching_key', True), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key('matching_key', []), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key('matching_key', ''), 'some_feature') == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment(Key('matching_key', 12345), 'some_feature') == 'default_treatment' + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment', 'bucketing_key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatment('matching_key', 'some_feature', True) == CONTROL + assert _logger.error.mock_calls == [ + mocker.call('%s: attributes must be of type dictionary.', 'get_treatment') + ] + + _logger.reset_mock() + assert await client.get_treatment('matching_key', 'some_feature', {'test': 'test'}) == 'default_treatment' + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.get_treatment('matching_key', 'some_feature', None) == 'default_treatment' + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.get_treatment('matching_key', ' some_feature ', None) == 'default_treatment' + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatment', 'feature flag name', ' some_feature ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return {'some_feature': None} + storage_mock.fetch_many = fetch_many + + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatment('matching_key', 'some_feature', None) == CONTROL + assert _logger.warning.mock_calls == [ + mocker.call( + "%s: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + 'get_treatment', + 'some_feature' + ) + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatment_with_config(self, mocker): + """Test get_treatment validation.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + storage_mock = mocker.Mock(spec=SplitStorage) + async def fetch_many(*_): + return { + 'some_feature': split_mock + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), ImpressionStorage, telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, mocker.Mock(), events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatment_with_config(None, 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('', 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') + ] + + _logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert await client.get_treatment_with_config(key, 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatment_with_config', 'key', 250) + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(12345, 'some_feature') == ('default_treatment', '{"some": "property"}') + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment_with_config', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(float('nan'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(float('inf'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(True, 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config([], 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('some_key', None) == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('some_key', 123) == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('some_key', True) == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('some_key', []) == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('some_key', '') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment_with_config', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('some_key', 'some_feature') == ('default_treatment', '{"some": "property"}') + assert _logger.error.mock_calls == [] + assert _logger.warning.mock_calls == [] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key(None, 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key('', 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment_with_config', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key(float('nan'), 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key(float('inf'), 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key(True, 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key([], 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'matching_key', 'matching_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key(12345, 'bucketing_key'), 'some_feature') == ('default_treatment', '{"some": "property"}') + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment_with_config', 'matching_key', 12345) + ] + + _logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert await client.get_treatment_with_config(Key(key, 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatment_with_config', 'matching_key', 250) + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key('matching_key', None), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment_with_config', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key('matching_key', True), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key('matching_key', []), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment_with_config', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key('matching_key', ''), 'some_feature') == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment_with_config', 'bucketing_key', 'bucketing_key') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config(Key('matching_key', 12345), 'some_feature') == ('default_treatment', '{"some": "property"}') + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment_with_config', 'bucketing_key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('matching_key', 'some_feature', True) == (CONTROL, None) + assert _logger.error.mock_calls == [ + mocker.call('%s: attributes must be of type dictionary.', 'get_treatment_with_config') + ] + + _logger.reset_mock() + assert await client.get_treatment_with_config('matching_key', 'some_feature', {'test': 'test'}) == ('default_treatment', '{"some": "property"}') + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.get_treatment_with_config('matching_key', 'some_feature', None) == ('default_treatment', '{"some": "property"}') + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.get_treatment_with_config('matching_key', ' some_feature ', None) == ('default_treatment', '{"some": "property"}') + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatment_with_config', 'feature flag name', ' some_feature ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return {'some_feature': None} + storage_mock.fetch_many = fetch_many + + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatment_with_config('matching_key', 'some_feature', None) == (CONTROL, None) + assert _logger.warning.mock_calls == [ + mocker.call( + "%s: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + 'get_treatment_with_config', + 'some_feature' + ) + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_track(self, mocker): + """Test track method().""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + events_storage_mock = mocker.Mock(spec=EventStorage) + async def put(*_): + return True + events_storage_mock.put = put + + event_storage = mocker.Mock(spec=EventStorage) + event_storage.put = put + split_storage_mock = mocker.Mock(spec=SplitStorage) + split_storage_mock.is_valid_traffic_type = put + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, events_storage_mock, ImpressionStorage, telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': split_storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': mocker.Mock(spec=RuleBasedSegmentsStorage), + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': events_storage_mock, + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + factory._sdk_key = 'some-test' + + client = ClientAsync(factory, recorder, events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + client._event_storage = event_storage + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.track(None, "traffic_type", "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.track("", "traffic_type", "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'track', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.track(12345, "traffic_type", "event_type", 1) is True + assert _logger.warning.mock_calls == [ + mocker.call("%s: %s %s is not of type string, converting.", 'track', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.track(True, "traffic_type", "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.track([], "traffic_type", "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'key', 'key') + ] + + _logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert await client.track(key, "traffic_type", "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: %s too long - must be %s characters or less.", 'track', 'key', 250) + ] + + _logger.reset_mock() + assert await client.track("some_key", None, "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "", "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", 12345, "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", True, "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", [], "event_type", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "TRAFFIC_type", "event_type", 1) is True + assert _logger.warning.mock_calls == [ + mocker.call("%s: %s '%s' should be all lowercase - converting string to lowercase", 'track', 'traffic type', 'TRAFFIC_type') + ] + + assert await client.track("some_key", "traffic_type", None, 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", True, 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", [], 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", 12345, 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "@@", 1) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'track', '@@', 'an event name', '^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$', 'an event name', 80) + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", None) is True + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", 1) is True + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", 1.23) is True + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", "test") is False + assert _logger.error.mock_calls == [ + mocker.call("track: value must be a number.") + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", True) is False + assert _logger.error.mock_calls == [ + mocker.call("track: value must be a number.") + ] + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", []) is False + assert _logger.error.mock_calls == [ + mocker.call("track: value must be a number.") + ] + + # Test traffic type existance + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + # Test that it doesn't warn if tt is cached, not in localhost mode and sdk is ready + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", None) is True + assert _logger.error.mock_calls == [] + assert _logger.warning.mock_calls == [] + + # Test that it does warn if tt is cached, not in localhost mode and sdk is ready + async def is_valid_traffic_type(*_): + return False + split_storage_mock.is_valid_traffic_type = is_valid_traffic_type + + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", None) is True + assert _logger.error.mock_calls == [] + assert _logger.warning.mock_calls == [mocker.call( + 'track: Traffic Type %s does not have any corresponding Feature flags in this environment, ' + 'make sure you\'re tracking your events to a valid traffic type defined ' + 'in the Split user interface.', + 'traffic_type' + )] + + # Test that it does not warn when in localhost mode. + factory._sdk_key = 'localhost' + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", None) is True + assert _logger.error.mock_calls == [] + assert _logger.warning.mock_calls == [] + + # Test that it does not warn when not in localhost mode and not ready + factory._sdk_key = 'not-localhost' + ready_property.return_value = False + type(factory).ready = ready_property + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", None) is True + assert _logger.error.mock_calls == [] + assert _logger.warning.mock_calls == [] + + # Test track with invalid properties + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", 1, []) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: properties must be of type dictionary.", "track") + ] + + # Test track with invalid properties + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", 1, True) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: properties must be of type dictionary.", "track") + ] + + # Test track with properties + props1 = { + "test1": "test", + "test2": 1, + "test3": True, + "test4": None, + "test5": [], + 2: "t", + } + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", 1, props1) is True + assert _logger.warning.mock_calls == [ + mocker.call("%s: Property %s is of invalid type. Setting value to None", "track", []) + ] + + # Test track with more than 300 properties + props2 = dict() + for i in range(301): + props2[str(i)] = i + _logger.reset_mock() + assert await client.track("some_key", "traffic_type", "event_type", 1, props2) is True + assert _logger.warning.mock_calls == [ + mocker.call("%s: Event has more than 300 properties. Some of them will be trimmed when processed", "track") + ] + + # Test track with properties higher than 32kb + _logger.reset_mock() + props3 = dict() + for i in range(100, 210): + props3["prop" + str(i)] = "a" * 300 + assert await client.track("some_key", "traffic_type", "event_type", 1, props3) is False + assert _logger.error.mock_calls == [ + mocker.call("%s: The maximum size allowed for the properties is 32768 bytes. Current one is 32952 bytes. Event not queued", "track") + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments(self, mocker): + """Test getTreatments() method.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=SplitStorage) + async def get(*_): + return split_mock + storage_mock.get = get + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + async def fetch_many(*_): + return { + 'some_feature': split_mock, + 'some': split_mock, + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, recorder, events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatments(None, ['some_feature']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments("", ['some_feature']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert await client.get_treatments(key, ['some_feature']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert await client.get_treatments(12345, ['some_feature']) == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatments(True, ['some_feature']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments([], ['some_feature']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments('some_key', None) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') + ] + + _logger.reset_mock() + assert await client.get_treatments('some_key', True) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') + ] + + _logger.reset_mock() + assert await client.get_treatments('some_key', 'some_string') == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') + ] + + _logger.reset_mock() + assert await client.get_treatments('some_key', []) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') + ] + + _logger.reset_mock() + assert await client.get_treatments('some_key', [None, None]) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') + ] + + _logger.reset_mock() + assert await client.get_treatments('some_key', [True]) == {} + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') in _logger.error.mock_calls + + _logger.reset_mock() + assert await client.get_treatments('some_key', ['', '']) == {} + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments') in _logger.error.mock_calls + + _logger.reset_mock() + assert await client.get_treatments('some_key', ['some_feature ']) == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments', 'feature flag name', 'some_feature ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return { + 'some_feature': None + } + storage_mock.fetch_many = fetch_many + + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatments('matching_key', ['some_feature'], None) == {'some_feature': CONTROL} + assert _logger.warning.mock_calls == [ + mocker.call( + "%s: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + 'get_treatments', + 'some_feature' + ) + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self, mocker): + """Test getTreatments() method.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + + storage_mock = mocker.Mock(spec=SplitStorage) + async def get(*_): + return split_mock + storage_mock.get = get + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + async def fetch_many(*_): + return { + 'some_feature': split_mock + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + impmanager, + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + split_mock.name = 'some_feature' + + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + + client = ClientAsync(factory, mocker.Mock(), events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatments_with_config(None, ['some_feature']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config("", ['some_feature']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_with_config', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert await client.get_treatments_with_config(key, ['some_feature']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_with_config', 'key', 250) + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config(12345, ['some_feature']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_with_config', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config(True, ['some_feature']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config([], ['some_feature']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', None) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', True) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', 'some_string') == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', []) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', [None, None]) == {} + assert _logger.error.mock_calls == [ + mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', [True]) == {} + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') in _logger.error.mock_calls + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', ['', '']) == {} + assert mocker.call('%s: feature flag names must be a non-empty array.', 'get_treatments_with_config') in _logger.error.mock_calls + + _logger.reset_mock() + assert await client.get_treatments_with_config('some_key', ['some_feature ']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config', 'feature flag name', 'some_feature ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return { + 'some_feature': None + } + storage_mock.fetch_many = fetch_many + + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatments('matching_key', ['some_feature'], None) == {'some_feature': CONTROL} + assert _logger.warning.mock_calls == [ + mocker.call( + "%s: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + 'get_treatments', + 'some_feature' + ) + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=SplitStorage) + async def get(*_): + return split_mock + storage_mock.get = get + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + async def fetch_many(*_): + return { + 'some_feature': split_mock, + 'some': split_mock, + } + storage_mock.fetch_many = fetch_many + async def get_feature_flags_by_sets(*_): + return ['some_feature'] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, recorder, events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatments_by_flag_set(None, 'some_flag') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_set("", 'some_flag') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert await client.get_treatments_by_flag_set(key, 'some_flag') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_by_flag_set', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert await client.get_treatments_by_flag_set(12345, 'some_flag') == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_by_flag_set', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_set(True, 'some_flag') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_set([], 'some_flag') == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + await client.get_treatments_by_flag_set('some_key', None) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_by_flag_set', 'flag set', 'flag set') + ] + + _logger.reset_mock() + await client.get_treatments_by_flag_set('some_key', "$$") + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_by_flag_set', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_set('some_key', 'some_flag ') == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_by_flag_set', 'flag set', 'some_flag ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return { + 'some_feature': None + } + storage_mock.fetch_many = fetch_many + + async def get_feature_flags_by_sets(*_): + return [] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatments_by_flag_set('matching_key', 'some_flag', None) == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_by_flag_set") + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=SplitStorage) + async def get(*_): + return split_mock + storage_mock.get = get + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + async def fetch_many(*_): + return { + 'some_feature': split_mock, + 'some': split_mock, + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + async def get_feature_flags_by_sets(*_): + return ['some_feature'] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, recorder, events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatments_by_flag_sets(None, ['some_flag']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_sets("", ['some_flag']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert await client.get_treatments_by_flag_sets(key, ['some_flag']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_by_flag_sets', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert await client.get_treatments_by_flag_sets(12345, ['some_flag']) == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_by_flag_sets', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_sets(True, ['some_flag']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_sets([], ['some_flag']) == {'some_feature': CONTROL} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + await client.get_treatments_by_flag_sets('some_key', None) + assert _logger.warning.mock_calls == [ + mocker.call("%s: flag sets parameter type should be list object, parameter is discarded", "get_treatments_by_flag_sets") + ] + + _logger.reset_mock() + await client.get_treatments_by_flag_sets('some_key', [None]) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_by_flag_sets', 'flag set', 'flag set') + ] + + _logger.reset_mock() + await client.get_treatments_by_flag_sets('some_key', ["$$"]) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_by_flag_sets', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert await client.get_treatments_by_flag_sets('some_key', ['some_flag ']) == {'some_feature': 'default_treatment'} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_by_flag_sets', 'flag set', 'some_flag ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return { + 'some_feature': None + } + storage_mock.fetch_many = fetch_many + + async def get_feature_flags_by_sets(*_): + return [] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatments_by_flag_sets('matching_key', ['some_flag'], None) == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_by_flag_sets") + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + split_mock.name = 'some_feature' + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=SplitStorage) + async def get(*_): + return split_mock + storage_mock.get = get + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + async def fetch_many(*_): + return { + 'some_feature': split_mock, + 'some': split_mock, + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + async def get_feature_flags_by_sets(*_): + return ['some_feature'] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, recorder, events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatments_with_config_by_flag_set(None, 'some_flag') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_set("", 'some_flag') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_set(key, 'some_flag') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_with_config_by_flag_set', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_set(12345, 'some_flag') == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_with_config_by_flag_set', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_set(True, 'some_flag') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_set([], 'some_flag') == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_set', 'key', 'key') + ] + + _logger.reset_mock() + await client.get_treatments_with_config_by_flag_set('some_key', None) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_with_config_by_flag_set', 'flag set', 'flag set') + ] + + _logger.reset_mock() + await client.get_treatments_with_config_by_flag_set('some_key', "$$") + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_with_config_by_flag_set', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_set('some_key', 'some_flag ') == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config_by_flag_set', 'flag set', 'some_flag ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return { + 'some_feature': None + } + storage_mock.fetch_many = fetch_many + + async def get_feature_flags_by_sets(*_): + return [] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatments_with_config_by_flag_set('matching_key', 'some_flag', None) == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_with_config_by_flag_set") + ] + await factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self, mocker): + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + split_mock = mocker.Mock(spec=Split) + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + type(split_mock).prerequisites = [] + storage_mock = mocker.Mock(spec=SplitStorage) + async def get(*_): + return split_mock + storage_mock.get = get + async def get_change_number(*_): + return 1 + storage_mock.get_change_number = get_change_number + async def fetch_many(*_): + return { + 'some_feature': split_mock, + 'some': split_mock, + } + storage_mock.fetch_many = fetch_many + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def fetch_many_rbs(*_): + return {} + rbs_storage.fetch_many = fetch_many_rbs + + async def get_feature_flags_by_sets(*_): + return ['some_feature'] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': rbs_storage, + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + None + ) + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + + client = ClientAsync(factory, recorder, events_manager, mocker.Mock(), FallbackTreatmentCalculator(None)) + async def record_treatment_stats(*_): + pass + client._recorder.record_treatment_stats = record_treatment_stats + + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await client.get_treatments_with_config_by_flag_sets(None, ['some_flag']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_sets("", ['some_flag']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + key = ''.join('a' for _ in range(0, 255)) + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_sets(key, ['some_flag']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments_with_config_by_flag_sets', 'key', 250) + ] + + split_mock.name = 'some_feature' + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_sets(12345, ['some_flag']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments_with_config_by_flag_sets', 'key', 12345) + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_sets(True, ['some_flag']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_sets([], ['some_flag']) == {'some_feature': (CONTROL, None)} + assert _logger.error.mock_calls == [ + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments_with_config_by_flag_sets', 'key', 'key') + ] + + _logger.reset_mock() + await client.get_treatments_with_config_by_flag_sets('some_key', None) + assert _logger.warning.mock_calls == [ + mocker.call("%s: flag sets parameter type should be list object, parameter is discarded", "get_treatments_with_config_by_flag_sets") + ] + + _logger.reset_mock() + await client.get_treatments_with_config_by_flag_sets('some_key', [None]) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", + 'get_treatments_with_config_by_flag_sets', 'flag set', 'flag set') + ] + + _logger.reset_mock() + await client.get_treatments_with_config_by_flag_sets('some_key', ["$$"]) + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed %s, %s must adhere to the regular " + "expression %s. This means " + "%s must be alphanumeric, cannot be more than %s " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.", + 'get_treatments_with_config_by_flag_sets', '$$', 'a flag set', '^[a-z0-9][_a-z0-9]{0,49}$', 'a flag set', 50) + ] + + _logger.reset_mock() + assert await client.get_treatments_with_config_by_flag_sets('some_key', ['some_flag ']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert _logger.warning.mock_calls == [ + mocker.call('%s: %s \'%s\' has extra whitespace, trimming.', 'get_treatments_with_config_by_flag_sets', 'flag set', 'some_flag ') + ] + + _logger.reset_mock() + async def fetch_many(*_): + return { + 'some_feature': None + } + storage_mock.fetch_many = fetch_many + + async def get_feature_flags_by_sets(*_): + return [] + storage_mock.get_feature_flags_by_sets = get_feature_flags_by_sets + + ready_mock = mocker.PropertyMock() + ready_mock.return_value = True + type(factory).ready = ready_mock + mocker.patch('splitio.client.client._LOGGER', new=_logger) + assert await client.get_treatments_with_config_by_flag_sets('matching_key', ['some_flag'], None) == {} + assert _logger.warning.mock_calls == [ + mocker.call("%s: No valid Flag set or no feature flags found for evaluating treatments", "get_treatments_with_config_by_flag_sets") + ] + await factory.destroy() + + + def test_flag_sets_validation(self): + """Test sanitization for flag sets.""" + flag_sets = input_validator.validate_flag_sets([' set1', 'set2 ', 'set3'], 'method') + assert sorted(flag_sets) == ['set1', 'set2', 'set3'] + + flag_sets = input_validator.validate_flag_sets(['1set', '_set2'], 'method') + assert flag_sets == ['1set'] + + flag_sets = input_validator.validate_flag_sets(['Set1', 'SET2'], 'method') + assert sorted(flag_sets) == ['set1', 'set2'] + + flag_sets = input_validator.validate_flag_sets(['se\t1', 's/et2', 's*et3', 's!et4', 'se@t5', 'se#t5', 'se$t5', 'se^t5', 'se%t5', 'se&t5'], 'method') + assert flag_sets == [] + + flag_sets = input_validator.validate_flag_sets(['set4', 'set1', 'set3', 'set1'], 'method') + assert sorted(flag_sets) == ['set1', 'set3', 'set4'] + + flag_sets = input_validator.validate_flag_sets(['w' * 50, 's' * 51], 'method') + assert flag_sets == ['w' * 50] + + flag_sets = input_validator.validate_flag_sets('set1', 'method') + assert flag_sets == [] + + flag_sets = input_validator.validate_flag_sets([12, 33], 'method') + assert flag_sets == [] + + +class ManagerInputValidationTests(object): #pylint: disable=too-few-public-methods + """Manager input validation test cases.""" + + def test_split_(self, mocker): + """Test split input validation.""" + storage_mock = mocker.Mock(spec=SplitStorage) + split_mock = mocker.Mock(spec=Split) + storage_mock.get.return_value = split_mock + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': mocker.Mock(spec=RuleBasedSegmentsStorage), + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + manager = SplitManager(factory) + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert manager.split(None) is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert manager.split("") is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert manager.split(True) is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert manager.split([]) is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + manager.split('some_split') + assert split_mock.to_split_view.mock_calls == [mocker.call()] + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + split_mock.reset_mock() + storage_mock.get.return_value = None + manager.split('nonexistant-split') + assert split_mock.to_split_view.mock_calls == [] + assert _logger.warning.mock_calls == [mocker.call( + "split: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + 'nonexistant-split' + )] + +class ManagerInputValidationAsyncTests(object): #pylint: disable=too-few-public-methods + """Manager input validation test cases.""" + + @pytest.mark.asyncio + async def test_split_(self, mocker): + """Test split input validation.""" + internal_events_queue = asyncio.Queue() + events_manager = mocker.Mock(EventsManagerAsync) + async def notify_internal_event(sdk_internal_event, event_metadata): + pass + events_manager.notify_internal_event = notify_internal_event + + storage_mock = mocker.Mock(spec=SplitStorage) + split_mock = mocker.Mock(spec=Split) + async def get(*_): + return split_mock + storage_mock.get = get + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(spec=EventStorage), mocker.Mock(spec=ImpressionStorage), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactoryAsync(mocker.Mock(), + { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'rule_based_segments': mocker.Mock(spec=RuleBasedSegmentsStorage), + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + }, + mocker.Mock(), + recorder, + internal_events_queue, + events_manager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock() + ) + + manager = SplitManagerAsync(factory) + _logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=_logger) + + assert await manager.split(None) is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await manager.split("") is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await manager.split(True) is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + assert await manager.split([]) is None + assert _logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_flag_name', 'feature_flag_name') + ] + + _logger.reset_mock() + await manager.split('some_split') + assert split_mock.to_split_view.mock_calls == [mocker.call()] + assert _logger.error.mock_calls == [] + + _logger.reset_mock() + split_mock.reset_mock() + async def get(*_): + return None + storage_mock.get = get + + await manager.split('nonexistant-split') + assert split_mock.to_split_view.mock_calls == [] + assert _logger.warning.mock_calls == [mocker.call( + "split: you passed \"%s\" that does not exist in this environment, " + "please double check what Feature flags exist in the Split user interface.", + 'nonexistant-split' + )] + +class FactoryInputValidationTests(object): #pylint: disable=too-few-public-methods + """Factory instantiation input validation test cases.""" + + def test_input_validation_factory(self, mocker): + """Test the input validators for factory instantiation.""" + logger = mocker.Mock(spec=logging.Logger) + mocker.patch('splitio.client.input_validator._LOGGER', new=logger) + + assert get_factory(None) is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'factory_instantiation', 'sdk_key', 'sdk_key') + ] + + logger.reset_mock() + assert get_factory('') is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'factory_instantiation', 'sdk_key', 'sdk_key') + ] + + logger.reset_mock() + assert get_factory(True) is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'factory_instantiation', 'sdk_key', 'sdk_key') + ] + + logger.reset_mock() + try: + f = get_factory(True, config={'redisHost': 'localhost'}) + except: + pass + assert logger.error.mock_calls == [] + f.destroy() + + +class FactoryInputValidationAsyncTests(object): #pylint: disable=too-few-public-methods + """Factory instantiation input validation test cases.""" + + @pytest.mark.asyncio + async def test_input_validation_factory(self, mocker): + """Test the input validators for factory instantiation.""" + logger = mocker.Mock(spec=logging.Logger) + mocker.patch('splitio.client.input_validator._LOGGER', new=logger) + + assert await get_factory_async(None) is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'factory_instantiation', 'sdk_key', 'sdk_key') + ] + + logger.reset_mock() + assert await get_factory_async('') is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'factory_instantiation', 'sdk_key', 'sdk_key') + ] + + logger.reset_mock() + assert await get_factory_async(True) is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'factory_instantiation', 'sdk_key', 'sdk_key') + ] + + logger.reset_mock() + try: + f = await get_factory_async(True, config={'redisHost': 'localhost'}) + except: + pass + assert logger.error.mock_calls == [] + await f.destroy() + +class PluggableInputValidationTests(object): #pylint: disable=too-few-public-methods + """Pluggable adapter instance validation test cases.""" + + class mock_adapter0(): + def set(self, key, value): + print(key) + + class mock_adapter1(object): + def set(self, key, value): + print(key) + + class mock_adapter2(mock_adapter1): + def get(self, key): + print(key) + + def get_items(self, key): + print(key) + + def get_many(self, keys): + print(keys) + + def push_items(self, key, *value): + print(key) + + def delete(self, key): + print(key) + + def increment(self, key, value): + print(key) + + def decrement(self, key, value): + print(key) + + def get_keys_by_prefix(self, prefix): + print(prefix) + + def get_many(self, keys): + print(keys) + + def add_items(self, key, added_items): + print(key) + + def remove_items(self, key, removed_items): + print(key) + + def item_contains(self, key, item): + print(key) + + def get_items_count(self, key): + print(key) + + class mock_adapter3(mock_adapter2): + def expire(self, key): + print(key) + + class mock_adapter4(mock_adapter2): + def expire(self, key, value, till): + print(key) + + def test_validate_pluggable_adapter(self): + # missing storageWrapper config parameter + assert(not input_validator.validate_pluggable_adapter({'storageType': 'pluggable'})) + + # ignore if storage type is not pluggable + assert(input_validator.validate_pluggable_adapter({'storageType': 'memory'})) + + # mock adapter is not derived from object class + assert(not input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'pe': self.mock_adapter0()})) + + # mock adapter missing many functions + assert(not input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'storageWrapper': self.mock_adapter1()})) + + # mock adapter missing expire function + assert(not input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'storageWrapper': self.mock_adapter2()})) + + # mock adapter expire function has incrrect args count + assert(not input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'storageWrapper': self.mock_adapter3()})) + + # expected mock adapter should pass + assert(input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'storageWrapper': self.mock_adapter4()})) + + # using string type prefix should pass + assert(input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'storagePrefix': 'myprefix', 'storageWrapper': self.mock_adapter4()})) + + # using non-string type prefix should not pass + assert(not input_validator.validate_pluggable_adapter({'storageType': 'pluggable', 'storagePrefix': 'myprefix', 123: self.mock_adapter4()})) + + def test_sanitize_flag_sets(self): + """Test sanitization for flag sets.""" + flag_sets = input_validator.validate_flag_sets([' set1', 'set2 ', 'set3'], 'm') + assert sorted(flag_sets) == ['set1', 'set2', 'set3'] + + flag_sets = input_validator.validate_flag_sets(['1set', '_set2'], 'm') + assert flag_sets == ['1set'] + + flag_sets = input_validator.validate_flag_sets(['Set1', 'SET2'], 'm') + assert sorted(flag_sets) == ['set1', 'set2'] + + flag_sets = input_validator.validate_flag_sets(['se\t1', 's/et2', 's*et3', 's!et4', 'se@t5', 'se#t5', 'se$t5', 'se^t5', 'se%t5', 'se&t5'], 'm') + assert flag_sets == [] + + flag_sets = input_validator.validate_flag_sets(['set4', 'set1', 'set3', 'set1'], 'm') + assert sorted(flag_sets) == ['set1', 'set3', 'set4'] + + flag_sets = input_validator.validate_flag_sets(['w' * 50, 's' * 51], 'm') + assert flag_sets == ['w' * 50] + + flag_sets = input_validator.validate_flag_sets('set1', 'm') + assert flag_sets == [] + + flag_sets = input_validator.validate_flag_sets([12, 33], 'm') + + assert flag_sets == [] diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py index d211bf2c..598d6100 100644 --- a/tests/client/test_localhost.py +++ b/tests/client/test_localhost.py @@ -6,7 +6,7 @@ from splitio.sync.split import LocalSplitSynchronizer from splitio.models.splits import Split from splitio.models.grammar.matchers import AllKeysMatcher -from splitio.storage import SplitStorage +from splitio.storage import SplitStorage, RuleBasedSegmentsStorage class LocalHostStoragesTests(object): @@ -72,7 +72,7 @@ def test_make_whitelist_condition(self): def test_parse_legacy_file(self): """Test that aprsing a legacy file works.""" filename = os.path.join(os.path.dirname(__file__), 'files', 'file1.split') - splits = LocalSplitSynchronizer._read_splits_from_legacy_file(filename) + splits = LocalSplitSynchronizer._read_feature_flags_from_legacy_file(filename) assert len(splits) == 2 for split in splits.values(): assert isinstance(split, Split) @@ -84,7 +84,7 @@ def test_parse_legacy_file(self): def test_parse_yaml_file(self): """Test that parsing a yaml file works.""" filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') - splits = LocalSplitSynchronizer._read_splits_from_yaml_file(filename) + splits = LocalSplitSynchronizer._read_feature_flags_from_yaml_file(filename) assert len(splits) == 4 for split in splits.values(): assert isinstance(split, Split) @@ -112,48 +112,48 @@ def test_update_splits(self, mocker): parse_yaml.return_value = {} storage_mock = mocker.Mock(spec=SplitStorage) storage_mock.get_split_names.return_value = [] - + rbs = mocker.Mock(spec=RuleBasedSegmentsStorage) parse_legacy.reset_mock() parse_yaml.reset_mock() - sync = LocalSplitSynchronizer('something', storage_mock) - sync._read_splits_from_legacy_file = parse_legacy - sync._read_splits_from_yaml_file = parse_yaml + sync = LocalSplitSynchronizer('something', storage_mock, rbs) + sync._read_feature_flags_from_legacy_file = parse_legacy + sync._read_feature_flags_from_yaml_file = parse_yaml sync.synchronize_splits() assert parse_legacy.mock_calls == [mocker.call('something')] assert parse_yaml.mock_calls == [] parse_legacy.reset_mock() parse_yaml.reset_mock() - sync = LocalSplitSynchronizer('something.yaml', storage_mock) - sync._read_splits_from_legacy_file = parse_legacy - sync._read_splits_from_yaml_file = parse_yaml + sync = LocalSplitSynchronizer('something.yaml', storage_mock, rbs) + sync._read_feature_flags_from_legacy_file = parse_legacy + sync._read_feature_flags_from_yaml_file = parse_yaml sync.synchronize_splits() assert parse_legacy.mock_calls == [] assert parse_yaml.mock_calls == [mocker.call('something.yaml')] parse_legacy.reset_mock() parse_yaml.reset_mock() - sync = LocalSplitSynchronizer('something.yml', storage_mock) - sync._read_splits_from_legacy_file = parse_legacy - sync._read_splits_from_yaml_file = parse_yaml + sync = LocalSplitSynchronizer('something.yml', storage_mock, rbs) + sync._read_feature_flags_from_legacy_file = parse_legacy + sync._read_feature_flags_from_yaml_file = parse_yaml sync.synchronize_splits() assert parse_legacy.mock_calls == [] assert parse_yaml.mock_calls == [mocker.call('something.yml')] parse_legacy.reset_mock() parse_yaml.reset_mock() - sync = LocalSplitSynchronizer('something.YAML', storage_mock) - sync._read_splits_from_legacy_file = parse_legacy - sync._read_splits_from_yaml_file = parse_yaml + sync = LocalSplitSynchronizer('something.YAML', storage_mock, rbs) + sync._read_feature_flags_from_legacy_file = parse_legacy + sync._read_feature_flags_from_yaml_file = parse_yaml sync.synchronize_splits() assert parse_legacy.mock_calls == [] assert parse_yaml.mock_calls == [mocker.call('something.YAML')] parse_legacy.reset_mock() parse_yaml.reset_mock() - sync = LocalSplitSynchronizer('yaml', storage_mock) - sync._read_splits_from_legacy_file = parse_legacy - sync._read_splits_from_yaml_file = parse_yaml + sync = LocalSplitSynchronizer('yaml', storage_mock, rbs) + sync._read_feature_flags_from_legacy_file = parse_legacy + sync._read_feature_flags_from_yaml_file = parse_yaml sync.synchronize_splits() assert parse_legacy.mock_calls == [mocker.call('yaml')] assert parse_yaml.mock_calls == [] diff --git a/tests/client/test_manager.py b/tests/client/test_manager.py index eeb2f304..c5454f67 100644 --- a/tests/client/test_manager.py +++ b/tests/client/test_manager.py @@ -1,19 +1,69 @@ """SDK main manager test module.""" +import pytest +import queue +import asyncio from splitio.client.factory import SplitFactory -from splitio.client.manager import SplitManager, _LOGGER as _logger +from splitio.client.manager import SplitManager, SplitManagerAsync, _LOGGER as _logger +from splitio.models import splits +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync, InMemorySplitStorage, InMemorySplitStorageAsync +from splitio.engine.impressions.impressions import Manager as ImpressionManager +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync, TelemetryStorageConsumer, TelemetryStorageConsumerAsync +from splitio.recorder.recorder import StandardRecorder, StandardRecorderAsync +from tests.integration import splits_json - -class ManagerTests(object): # pylint: disable=too-few-public-methods +class SplitManagerTests(object): # pylint: disable=too-few-public-methods """Split manager test cases.""" + def test_manager_calls(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + + factory = mocker.Mock(spec=SplitFactory) + factory._storages = {'split': storage} + factory._telemetry_init_producer = telemetry_producer._telemetry_init_producer + factory.destroyed = False + factory._waiting_fork.return_value = False + factory.ready = True + + manager = SplitManager(factory) + split1 = splits.from_raw(splits_json["splitChange1_1"]['ff']['d'][0]) + split2 = splits.from_raw(splits_json["splitChange1_3"]['ff']['d'][0]) + storage.update([split1, split2], [], -1) + manager._storage = storage + + assert manager.split_names() == ['SPLIT_2', 'SPLIT_1'] + assert manager.split('SPLIT_3') is None + assert manager.split('SPLIT_2') == split1.to_split_view() + assert manager.splits() == [split.to_split_view() for split in storage.get_all_splits()] + def test_evaluations_before_running_post_fork(self, mocker): destroyed_property = mocker.PropertyMock() destroyed_property.return_value = False - factory = mocker.Mock(spec=SplitFactory) - factory._waiting_fork.return_value = True - type(factory).destroyed = destroyed_property + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + recorder = StandardRecorder(impmanager, mocker.Mock(), mocker.Mock(), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': mocker.Mock(), + 'segments': mocker.Mock(), + 'impressions': mocker.Mock(), + 'events': mocker.Mock()}, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + True + ) expected_msg = [ mocker.call('Client is not ready - no calls possible') @@ -34,3 +84,79 @@ def test_evaluations_before_running_post_fork(self, mocker): assert manager.splits() == [] assert _logger.error.mock_calls == expected_msg _logger.reset_mock() + + +class SplitManagerAsyncTests(object): # pylint: disable=too-few-public-methods + """Split manager test cases.""" + + @pytest.mark.asyncio + async def test_manager_calls(self, mocker): + internal_events_queue = asyncio.Queue() + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + storage = InMemorySplitStorageAsync(internal_events_queue) + + factory = mocker.Mock(spec=SplitFactory) + factory._storages = {'split': storage} + factory._telemetry_init_producer = telemetry_producer._telemetry_init_producer + factory.destroyed = False + factory._waiting_fork.return_value = False + factory.ready = True + + manager = SplitManagerAsync(factory) + split1 = splits.from_raw(splits_json["splitChange1_1"]['ff']['d'][0]) + split2 = splits.from_raw(splits_json["splitChange1_3"]['ff']['d'][0]) + await storage.update([split1, split2], [], -1) + manager._storage = storage + + assert await manager.split_names() == ['SPLIT_2', 'SPLIT_1'] + assert await manager.split('SPLIT_3') is None + assert await manager.split('SPLIT_2') == split1.to_split_view() + assert await manager.splits() == [split.to_split_view() for split in await storage.get_all_splits()] + + @pytest.mark.asyncio + async def test_evaluations_before_running_post_fork(self, mocker): + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + impmanager = mocker.Mock(spec=ImpressionManager) + telemetry_storage = InMemoryTelemetryStorageAsync() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + recorder = StandardRecorderAsync(impmanager, mocker.Mock(), mocker.Mock(), telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_producer.get_telemetry_runtime_producer()) + factory = SplitFactory(mocker.Mock(), + {'splits': mocker.Mock(), + 'segments': mocker.Mock(), + 'impressions': mocker.Mock(), + 'events': mocker.Mock()}, + mocker.Mock(), + recorder, + mocker.Mock(), + mocker.Mock(), + impmanager, + mocker.Mock(), + telemetry_producer, + telemetry_producer.get_telemetry_init_producer(), + mocker.Mock(), + True + ) + + expected_msg = [ + mocker.call('Client is not ready - no calls possible') + ] + + manager = SplitManagerAsync(factory) + _logger = mocker.Mock() + mocker.patch('splitio.client.manager._LOGGER', new=_logger) + + assert await manager.split_names() == [] + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() + + assert await manager.split('some_feature') is None + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() + + assert await manager.splits() == [] + assert _logger.error.mock_calls == expected_msg + _logger.reset_mock() diff --git a/tests/client/test_utils.py b/tests/client/test_utils.py index 807dc9c4..98d9d8f6 100644 --- a/tests/client/test_utils.py +++ b/tests/client/test_utils.py @@ -13,30 +13,29 @@ class ClientUtilsTests(object): def test_get_metadata(self, mocker): """Test the get_metadata function.""" - get_ip_mock = mocker.Mock() - get_host_mock = mocker.Mock() - mocker.patch('splitio.client.util._get_ip', new=get_ip_mock) - mocker.patch('splitio.client.util._get_hostname', new=get_host_mock) - meta = util.get_metadata({'machineIp': 'some_ip', 'machineName': 'some_machine_name'}) - assert get_ip_mock.mock_calls == [] - assert get_host_mock.mock_calls == [] assert meta.instance_ip == 'some_ip' assert meta.instance_name == 'some_machine_name' assert meta.sdk_version == 'python-' + __version__ - meta = util.get_metadata(config.DEFAULT_CONFIG) - assert get_ip_mock.mock_calls == [mocker.call()] - assert get_host_mock.mock_calls == [mocker.call(mocker.ANY)] - cfg = DEFAULT_CONFIG.copy() cfg.update({'IPAddressesEnabled': False}) meta = util.get_metadata(cfg) assert meta.instance_ip == 'NA' assert meta.instance_name == 'NA' - get_ip_mock.reset_mock() - get_host_mock.reset_mock() - meta = util.get_metadata({}) - assert get_ip_mock.mock_calls == [mocker.call()] - assert get_host_mock.mock_calls == [mocker.call(mocker.ANY)] + meta = util.get_metadata(config.DEFAULT_CONFIG) + ip_address, hostname = util._get_hostname_and_ip(config.DEFAULT_CONFIG) + assert meta.instance_ip != 'NA' + assert meta.instance_name != 'NA' + assert meta.instance_ip == ip_address + assert meta.instance_name == hostname + + self.called = 0 + def get_hostname_and_ip_mock(any): + self.called += 0 + return mocker.Mock(), mocker.Mock() + mocker.patch('splitio.client.util._get_hostname_and_ip', new=get_hostname_and_ip_mock) + + meta = util.get_metadata(config.DEFAULT_CONFIG) + self.called = 1 \ No newline at end of file diff --git a/tests/engine/files/rule_base_segments.json b/tests/engine/files/rule_base_segments.json new file mode 100644 index 00000000..70b64b32 --- /dev/null +++ b/tests/engine/files/rule_base_segments.json @@ -0,0 +1,62 @@ +{"ff": {"d": [], "t": -1, "s": -1}, +"rbs": {"t": -1, "s": -1, "d": + [{ + "changeNumber": 5, + "name": "dependent_rbs", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{"keys":["mauro@split.io","gaston@split.io"],"segments":[]}, + "conditions": [ + { + "conditionType": "WHITELIST", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": false, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ]}, + { + "changeNumber": 5, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded": { + "keys": [], + "segments": [] + }, + "conditions": [ + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "dependent_rbs" + } + } + ] + } + } + ] + }] +}} diff --git a/tests/engine/files/rule_base_segments2.json b/tests/engine/files/rule_base_segments2.json new file mode 100644 index 00000000..2f77ecd5 --- /dev/null +++ b/tests/engine/files/rule_base_segments2.json @@ -0,0 +1,67 @@ +{"ff": {"d": [], "t": -1, "s": -1}, +"rbs": {"t": -1, "s": -1, "d": [ + { + "changeNumber": 5, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[{"type":"rule-based", "name":"no_excludes"}] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "STARTS_WITH", + "negate": false, + "whitelistMatcherData": { + "whitelist": [ + "bilal" + ] + } + } + ] + } + } + ] + }, + { + "changeNumber": 5, + "name": "no_excludes", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["bilal2@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": false, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] + } +]}} diff --git a/tests/engine/files/rule_base_segments3.json b/tests/engine/files/rule_base_segments3.json new file mode 100644 index 00000000..f738f3f7 --- /dev/null +++ b/tests/engine/files/rule_base_segments3.json @@ -0,0 +1,35 @@ +{"ff": {"d": [], "t": -1, "s": -1}, +"rbs": {"t": -1, "s": -1, "d": [ + { + "changeNumber": 5, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[{"type":"standard", "name":"segment1"}] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": false, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] + } +]}} diff --git a/tests/engine/test_bloom_filter.py b/tests/engine/test_bloom_filter.py new file mode 100644 index 00000000..303b22e0 --- /dev/null +++ b/tests/engine/test_bloom_filter.py @@ -0,0 +1,55 @@ +"""BloomFilter unit tests.""" + +from random import random +import uuid +from splitio.engine.filters import BloomFilter + +class BloomFilterTests(object): + """StandardRecorderTests test cases.""" + + def test_bloom_filter_methods(self, mocker): + bloom_filter = BloomFilter() + key1 = str(uuid.uuid4()) + key2 = str(uuid.uuid4()) + bloom_filter.add(key1) + + assert(bloom_filter.contains(key1)) + assert(not bloom_filter.contains(key2)) + + bloom_filter.clear() + assert(not bloom_filter.contains(key1)) + + bloom_filter.add(key1) + bloom_filter.add(key2) + assert(bloom_filter.contains(key1)) + assert(bloom_filter.contains(key2)) + + def test_bloom_filter_error_percentage(self, mocker): + arr_storage = [] + total_sample = 20000 + error_rate = 0.01 + bloom_filter = BloomFilter(total_sample, error_rate) + + for x in range(1, total_sample): + myuuid = str(uuid.uuid4()) + bloom_filter.add(myuuid) + arr_storage.append(myuuid) + + false_positive_count = 0 + for x in range(1, total_sample): + y = int(random()*total_sample*5) + if y > total_sample - 2: + myuuid = str(uuid.uuid4()) + if myuuid in arr_storage: + # False Negative + assert(bloom_filter.contains(myuuid)) + else: + if bloom_filter.contains(myuuid): + # False Positive + false_positive_count = false_positive_count + 1 + else: + myuuid = arr_storage[y] + assert(bloom_filter.contains(myuuid)) + # False Negative + + assert(false_positive_count/total_sample <= error_rate) \ No newline at end of file diff --git a/tests/engine/test_evaluator.py b/tests/engine/test_evaluator.py index 65bdf782..edf510c0 100644 --- a/tests/engine/test_evaluator.py +++ b/tests/engine/test_evaluator.py @@ -1,36 +1,128 @@ """Evaluator tests module.""" +import json import logging +import os +import pytest +import copy +import queue +import asyncio -from splitio.models.splits import Split +from splitio.models.splits import Split, Status, from_raw, Prerequisites +from splitio.models import segments from splitio.models.grammar.condition import Condition, ConditionType from splitio.models.impressions import Label +from splitio.models.grammar import condition +from splitio.models import rule_based_segments +from splitio.models.fallback_treatment import FallbackTreatment +from splitio.models.fallback_config import FallbackTreatmentsConfiguration, FallbackTreatmentCalculator from splitio.engine import evaluator, splitters -from splitio.storage import SplitStorage, SegmentStorage +from splitio.engine.evaluator import EvaluationContext +from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, InMemoryRuleBasedSegmentStorage, \ + InMemorySplitStorageAsync, InMemorySegmentStorageAsync, InMemoryRuleBasedSegmentStorageAsync +from splitio.engine.evaluator import EvaluationDataFactory, AsyncEvaluationDataFactory +rbs_raw = { + "changeNumber": 123, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": False, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] +} +split_conditions = [ + condition.from_raw({ + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "sample_rule_based_segment" + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ], + "label": "in rule based segment sample_rule_based_segment" + }), + condition.from_raw({ + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "ALL_KEYS", + "negate": False + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ], + "label": "default rule" + }) +] + class EvaluatorTests(object): """Test evaluator behavior.""" def _build_evaluator_with_mocks(self, mocker): """Build an evaluator with mocked dependencies.""" - split_storage_mock = mocker.Mock(spec=SplitStorage) splitter_mock = mocker.Mock(spec=splitters.Splitter) - segment_storage_mock = mocker.Mock(spec=SegmentStorage) logger_mock = mocker.Mock(spec=logging.Logger) - e = evaluator.Evaluator(split_storage_mock, segment_storage_mock, splitter_mock) + e = evaluator.Evaluator(splitter_mock) evaluator._LOGGER = logger_mock return e - - def test_evaluate_treatment_missing_split(self, mocker): - """Test that a missing split logs and returns CONTROL.""" - e = self._build_evaluator_with_mocks(mocker) - e._split_storage.get.return_value = None - result = e.evaluate_feature('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) - assert result['configurations'] == None - assert result['treatment'] == evaluator.CONTROL - assert result['impression']['change_number'] == -1 - assert result['impression']['label'] == Label.SPLIT_NOT_FOUND - + def test_evaluate_treatment_killed_split(self, mocker): """Test that a killed split returns the default treatment.""" e = self._build_evaluator_with_mocks(mocker) @@ -39,8 +131,10 @@ def test_evaluate_treatment_killed_split(self, mocker): mocked_split.killed = True mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = '{"some_property": 123}' - e._split_storage.get.return_value = mocked_split - result = e.evaluate_feature('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + mocked_split.prerequisites = [] + + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), rbs_segments={}) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some', {}, ctx) assert result['treatment'] == 'off' assert result['configurations'] == '{"some_property": 123}' assert result['impression']['change_number'] == 123 @@ -50,34 +144,38 @@ def test_evaluate_treatment_killed_split(self, mocker): def test_evaluate_treatment_ok(self, mocker): """Test that a non-killed split returns the appropriate treatment.""" e = self._build_evaluator_with_mocks(mocker) - e._get_treatment_for_split = mocker.Mock() - e._get_treatment_for_split.return_value = ('on', 'some_label') + e._treatment_for_flag = mocker.Mock() + e._treatment_for_flag.return_value = ('on', 'some_label') mocked_split = mocker.Mock(spec=Split) mocked_split.default_treatment = 'off' mocked_split.killed = False mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = '{"some_property": 123}' - e._split_storage.get.return_value = mocked_split - result = e.evaluate_feature('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + mocked_split.prerequisites = [] + + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), rbs_segments={}) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some', {}, ctx) assert result['treatment'] == 'on' assert result['configurations'] == '{"some_property": 123}' assert result['impression']['change_number'] == 123 assert result['impression']['label'] == 'some_label' assert mocked_split.get_configurations_for.mock_calls == [mocker.call('on')] - + assert result['impressions_disabled'] == mocked_split.impressions_disabled def test_evaluate_treatment_ok_no_config(self, mocker): """Test that a killed split returns the default treatment.""" e = self._build_evaluator_with_mocks(mocker) - e._get_treatment_for_split = mocker.Mock() - e._get_treatment_for_split.return_value = ('on', 'some_label') + e._treatment_for_flag = mocker.Mock() + e._treatment_for_flag.return_value = ('on', 'some_label') mocked_split = mocker.Mock(spec=Split) mocked_split.default_treatment = 'off' mocked_split.killed = False mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = None - e._split_storage.get.return_value = mocked_split - result = e.evaluate_feature('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + mocked_split.prerequisites = [] + + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), rbs_segments={}) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some', {}, ctx) assert result['treatment'] == 'on' assert result['configurations'] == None assert result['impression']['change_number'] == 123 @@ -87,24 +185,31 @@ def test_evaluate_treatment_ok_no_config(self, mocker): def test_evaluate_treatments(self, mocker): """Test that a missing split logs and returns CONTROL.""" e = self._build_evaluator_with_mocks(mocker) - e._get_treatment_for_split = mocker.Mock() - e._get_treatment_for_split.return_value = ('on', 'some_label') + e._treatment_for_flag = mocker.Mock() + e._treatment_for_flag.return_value = ('on', 'some_label') mocked_split = mocker.Mock(spec=Split) mocked_split.name = 'feature2' mocked_split.default_treatment = 'off' mocked_split.killed = False mocked_split.change_number = 123 mocked_split.get_configurations_for.return_value = '{"some_property": 123}' - e._split_storage.fetch_many.return_value = { - 'feature1': None, - 'feature2': mocked_split, - } - results = e.evaluate_features(['feature1', 'feature2'], 'some_key', 'some_bucketing_key', None) - result = results['feature1'] + mocked_split.prerequisites = [] + + mocked_split2 = mocker.Mock(spec=Split) + mocked_split2.name = 'feature4' + mocked_split2.default_treatment = 'on' + mocked_split2.killed = False + mocked_split2.change_number = 123 + mocked_split2.get_configurations_for.return_value = None + mocked_split2.prerequisites = [] + + ctx = EvaluationContext(flags={'feature2': mocked_split, 'feature4': mocked_split2}, segment_memberships=set(), rbs_segments={}) + results = e.eval_many_with_context('some_key', 'some_bucketing_key', ['feature2', 'feature4'], {}, ctx) + result = results['feature4'] assert result['configurations'] == None - assert result['treatment'] == evaluator.CONTROL - assert result['impression']['change_number'] == -1 - assert result['impression']['label'] == Label.SPLIT_NOT_FOUND + assert result['treatment'] == 'on' + assert result['impression']['change_number'] == 123 + assert result['impression']['label'] == 'some_label' result = results['feature2'] assert result['configurations'] == '{"some_property": 123}' assert result['treatment'] == 'on' @@ -115,14 +220,19 @@ def test_get_gtreatment_for_split_no_condition_matches(self, mocker): """Test no condition matches.""" e = self._build_evaluator_with_mocks(mocker) e._splitter.get_treatment.return_value = 'on' - conditions_mock = mocker.PropertyMock() - conditions_mock.return_value = [] mocked_split = mocker.Mock(spec=Split) mocked_split.killed = False - type(mocked_split).conditions = conditions_mock - treatment, label = e._get_treatment_for_split(mocked_split, 'some_key', 'some_bucketing', {'attr1': 1}) - assert treatment == None - assert label == None + mocked_split.default_treatment = 'off' + mocked_split.change_number = '123' + mocked_split.conditions = [] + mocked_split.get_configurations_for = None + mocked_split.prerequisites = [] + + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), rbs_segments={}) + assert e._treatment_for_flag(mocked_split, 'some_key', 'some_bucketing', {}, ctx) == ( + 'off', + Label.NO_CONDITION_MATCHED + ) def test_get_gtreatment_for_split_non_rollout(self, mocker): """Test condition matches.""" @@ -132,30 +242,391 @@ def test_get_gtreatment_for_split_non_rollout(self, mocker): mocked_condition_1.condition_type = ConditionType.WHITELIST mocked_condition_1.label = 'some_label' mocked_condition_1.matches.return_value = True - conditions_mock = mocker.PropertyMock() - conditions_mock.return_value = [mocked_condition_1] mocked_split = mocker.Mock(spec=Split) mocked_split.killed = False - type(mocked_split).conditions = conditions_mock - treatment, label = e._get_treatment_for_split(mocked_split, 'some_key', 'some_bucketing', {'attr1': 1}) + mocked_split.conditions = [mocked_condition_1] + mocked_split.prerequisites = [] + + treatment, label = e._treatment_for_flag(mocked_split, 'some_key', 'some_bucketing', {}, EvaluationContext(None, None, None)) assert treatment == 'on' assert label == 'some_label' - def test_get_treatment_for_split_rollout(self, mocker): - """Test rollout condition returns default treatment.""" - e = self._build_evaluator_with_mocks(mocker) - e._splitter.get_bucket.return_value = 60 - mocked_condition_1 = mocker.Mock(spec=Condition) - mocked_condition_1.condition_type = ConditionType.ROLLOUT - mocked_condition_1.label = 'some_label' - mocked_condition_1.matches.return_value = True - conditions_mock = mocker.PropertyMock() - conditions_mock.return_value = [mocked_condition_1] + def test_evaluate_treatment_with_rule_based_segment(self, mocker): + """Test that a non-killed split returns the appropriate treatment.""" + e = evaluator.Evaluator(splitters.Splitter()) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, []) + + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), rbs_segments={'sample_rule_based_segment': rule_based_segments.from_raw(rbs_raw)}) + result = e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx) + assert result['treatment'] == 'on' + + def test_evaluate_treatment_with_rbs_in_condition(self): + e = evaluator.Evaluator(splitters.Splitter()) + events_queue = queue.Queue() + splits_storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + evaluation_facctory = EvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + rbs_segments = os.path.join(os.path.dirname(__file__), 'files', 'rule_base_segments.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, []) + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + rbs2 = rule_based_segments.from_raw(data["rbs"]["d"][1]) + rbs_storage.update([rbs, rbs2], [], 12) + splits_storage.update([mocked_split], [], 12) + + ctx = evaluation_facctory.context_for('bilal@split.io', ['some']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on" + + ctx = evaluation_facctory.context_for('mauro@split.io', ['some']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'some', {'email': 'mauro@split.io'}, ctx)['treatment'] == "off" + + def test_using_segment_in_excluded(self): + rbs_segments = os.path.join(os.path.dirname(__file__), 'files', 'rule_base_segments3.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + e = evaluator.Evaluator(splitters.Splitter()) + events_queue = queue.Queue() + splits_storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + evaluation_facctory = EvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, []) + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + rbs_storage.update([rbs], [], 12) + splits_storage.update([mocked_split], [], 12) + segment = segments.from_raw({'name': 'segment1', 'added': ['pato@split.io'], 'removed': [], 'till': 123}) + segment_storage.put(segment) + + ctx = evaluation_facctory.context_for('bilal@split.io', ['some']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on" + ctx = evaluation_facctory.context_for('mauro@split.io', ['some']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'some', {'email': 'mauro@split.io'}, ctx)['treatment'] == "off" + ctx = evaluation_facctory.context_for('pato@split.io', ['some']) + assert e.eval_with_context('pato@split.io', 'pato@split.io', 'some', {'email': 'pato@split.io'}, ctx)['treatment'] == "off" + + def test_using_rbs_in_excluded(self): + rbs_segments = os.path.join(os.path.dirname(__file__), 'files', 'rule_base_segments2.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + e = evaluator.Evaluator(splitters.Splitter()) + events_queue = queue.Queue() + splits_storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + evaluation_facctory = EvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, []) + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + rbs2 = rule_based_segments.from_raw(data["rbs"]["d"][1]) + rbs_storage.update([rbs, rbs2], [], 12) + splits_storage.update([mocked_split], [], 12) + + ctx = evaluation_facctory.context_for('bilal@split.io', ['some']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx)['treatment'] == "off" + ctx = evaluation_facctory.context_for('bilal', ['some']) + assert e.eval_with_context('bilal', 'bilal', 'some', {'email': 'bilal'}, ctx)['treatment'] == "on" + ctx = evaluation_facctory.context_for('bilal2@split.io', ['some']) + assert e.eval_with_context('bilal2@split.io', 'bilal2@split.io', 'some', {'email': 'bilal2@split.io'}, ctx)['treatment'] == "on" + + def test_prerequisites(self): + splits_load = os.path.join(os.path.dirname(__file__), '../models/grammar/files', 'splits_prereq.json') + with open(splits_load, 'r') as flo: + data = json.loads(flo.read()) + e = evaluator.Evaluator(splitters.Splitter()) + events_queue = queue.Queue() + splits_storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + evaluation_facctory = EvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + split1 = from_raw(data["ff"]["d"][0]) + split2 = from_raw(data["ff"]["d"][1]) + split3 = from_raw(data["ff"]["d"][2]) + split4 = from_raw(data["ff"]["d"][3]) + rbs_storage.update([rbs], [], 12) + splits_storage.update([split1, split2, split3, split4], [], 12) + segment = segments.from_raw({'name': 'segment-test', 'added': ['pato@split.io'], 'removed': [], 'till': 123}) + segment_storage.put(segment) + + ctx = evaluation_facctory.context_for('bilal@split.io', ['test_prereq']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'test_prereq', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on" + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'test_prereq', {}, ctx)['treatment'] == "def_treatment" + + ctx = evaluation_facctory.context_for('mauro@split.io', ['test_prereq']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'test_prereq', {'email': 'mauro@split.io'}, ctx)['treatment'] == "def_treatment" + + ctx = evaluation_facctory.context_for('pato@split.io', ['test_prereq']) + assert e.eval_with_context('pato@split.io', 'pato@split.io', 'test_prereq', {'email': 'pato@split.io'}, ctx)['treatment'] == "def_treatment" + + ctx = evaluation_facctory.context_for('nico@split.io', ['test_prereq']) + assert e.eval_with_context('nico@split.io', 'nico@split.io', 'test_prereq', {'email': 'nico@split.io'}, ctx)['treatment'] == "on" + + ctx = evaluation_facctory.context_for('bilal@split.io', ['prereq_chain']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'prereq_chain', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on_whitelist" + + ctx = evaluation_facctory.context_for('nico@split.io', ['prereq_chain']) + assert e.eval_with_context('nico@split.io', 'nico@split.io', 'test_prereq', {'email': 'nico@split.io'}, ctx)['treatment'] == "on" + + ctx = evaluation_facctory.context_for('pato@split.io', ['prereq_chain']) + assert e.eval_with_context('pato@split.io', 'pato@split.io', 'prereq_chain', {'email': 'pato@split.io'}, ctx)['treatment'] == "on_default" + + ctx = evaluation_facctory.context_for('mauro@split.io', ['prereq_chain']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'prereq_chain', {'email': 'mauro@split.io'}, ctx)['treatment'] == "on_default" + + def test_evaluate_treatment_with_fallback(self, mocker): + """Test that a evaluation return fallback treatment.""" + splitter_mock = mocker.Mock(spec=splitters.Splitter) + logger_mock = mocker.Mock(spec=logging.Logger) + evaluator._LOGGER = logger_mock mocked_split = mocker.Mock(spec=Split) - mocked_split.traffic_allocation = 50 - mocked_split.default_treatment = 'almost-on' - mocked_split.killed = False - type(mocked_split).conditions = conditions_mock - treatment, label = e._get_treatment_for_split(mocked_split, 'some_key', 'some_bucketing', {'attr1': 1}) - assert treatment == 'almost-on' - assert label == Label.NOT_IN_SPLIT + ctx = EvaluationContext(flags={'some': mocked_split}, segment_memberships=set(), rbs_segments={}) + + # should use global fallback + e = evaluator.Evaluator(splitter_mock, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("off-global", '{"prop": "val"}')))) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some2', {}, ctx) + assert result['treatment'] == 'off-global' + assert result['configurations'] == '{"prop": "val"}' + assert result['impression']['label'] == "fallback - " + Label.SPLIT_NOT_FOUND + + # should use by flag fallback + e = evaluator.Evaluator(splitter_mock, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {"some2": FallbackTreatment("off-some2", '{"prop2": "val2"}')}))) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some2', {}, ctx) + assert result['treatment'] == 'off-some2' + assert result['configurations'] == '{"prop2": "val2"}' + assert result['impression']['label'] == "fallback - " + Label.SPLIT_NOT_FOUND + + # should not use any fallback + e = evaluator.Evaluator(splitter_mock, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {"some2": FallbackTreatment("off-some2", '{"prop2": "val2"}')}))) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some3', {}, ctx) + assert result['treatment'] == 'control' + assert result['configurations'] == None + assert result['impression']['label'] == Label.SPLIT_NOT_FOUND + + # should use by flag fallback + e = evaluator.Evaluator(splitter_mock, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("off-global", '{"prop": "val"}'), {"some2": FallbackTreatment("off-some2", '{"prop2": "val2"}')}))) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some2', {}, ctx) + assert result['treatment'] == 'off-some2' + assert result['configurations'] == '{"prop2": "val2"}' + assert result['impression']['label'] == "fallback - " + Label.SPLIT_NOT_FOUND + + # should global flag fallback + e = evaluator.Evaluator(splitter_mock, FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("off-global", '{"prop": "val"}'), {"some2": FallbackTreatment("off-some2", '{"prop2": "val2"}')}))) + result = e.eval_with_context('some_key', 'some_bucketing_key', 'some3', {}, ctx) + assert result['treatment'] == 'off-global' + assert result['configurations'] == '{"prop": "val"}' + assert result['impression']['label'] == "fallback - " + Label.SPLIT_NOT_FOUND + + @pytest.mark.asyncio + async def test_evaluate_treatment_with_rbs_in_condition_async(self): + e = evaluator.Evaluator(splitters.Splitter()) + internal_events_queue = asyncio.Queue() + + splits_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + evaluation_facctory = AsyncEvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + rbs_segments = os.path.join(os.path.dirname(__file__), 'files', 'rule_base_segments.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False) + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + rbs2 = rule_based_segments.from_raw(data["rbs"]["d"][1]) + await rbs_storage.update([rbs, rbs2], [], 12) + await splits_storage.update([mocked_split], [], 12) + + ctx = await evaluation_facctory.context_for('bilal@split.io', ['some']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on" + ctx = await evaluation_facctory.context_for('mauro@split.io', ['some']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'some', {'email': 'mauro@split.io'}, ctx)['treatment'] == "off" + + @pytest.mark.asyncio + async def test_using_segment_in_excluded_async(self): + rbs_segments = os.path.join(os.path.dirname(__file__), 'files', 'rule_base_segments3.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + e = evaluator.Evaluator(splitters.Splitter()) + internal_events_queue = asyncio.Queue() + splits_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + evaluation_facctory = AsyncEvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False) + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + await rbs_storage.update([rbs], [], 12) + await splits_storage.update([mocked_split], [], 12) + segment = segments.from_raw({'name': 'segment1', 'added': ['pato@split.io'], 'removed': [], 'till': 123}) + await segment_storage.put(segment) + + ctx = await evaluation_facctory.context_for('bilal@split.io', ['some']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on" + ctx = await evaluation_facctory.context_for('mauro@split.io', ['some']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'some', {'email': 'mauro@split.io'}, ctx)['treatment'] == "off" + ctx = await evaluation_facctory.context_for('pato@split.io', ['some']) + assert e.eval_with_context('pato@split.io', 'pato@split.io', 'some', {'email': 'pato@split.io'}, ctx)['treatment'] == "off" + + @pytest.mark.asyncio + async def test_using_rbs_in_excluded_async(self): + rbs_segments = os.path.join(os.path.dirname(__file__), 'files', 'rule_base_segments2.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + e = evaluator.Evaluator(splitters.Splitter()) + internal_events_queue = asyncio.Queue() + splits_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + evaluation_facctory = AsyncEvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False) + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + rbs2 = rule_based_segments.from_raw(data["rbs"]["d"][1]) + await rbs_storage.update([rbs, rbs2], [], 12) + await splits_storage.update([mocked_split], [], 12) + + ctx = await evaluation_facctory.context_for('bilal@split.io', ['some']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'some', {'email': 'bilal@split.io'}, ctx)['treatment'] == "off" + ctx = await evaluation_facctory.context_for('bilal', ['some']) + assert e.eval_with_context('bilal', 'bilal', 'some', {'email': 'bilal'}, ctx)['treatment'] == "on" + ctx = await evaluation_facctory.context_for('bilal2@split.io', ['some']) + assert e.eval_with_context('bilal2@split.io', 'bilal2@split.io', 'some', {'email': 'bilal2@split.io'}, ctx)['treatment'] == "on" + + @pytest.mark.asyncio + async def test_prerequisites(self): + splits_load = os.path.join(os.path.dirname(__file__), '../models/grammar/files', 'splits_prereq.json') + with open(splits_load, 'r') as flo: + data = json.loads(flo.read()) + e = evaluator.Evaluator(splitters.Splitter()) + internal_events_queue = asyncio.Queue() + splits_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + evaluation_facctory = AsyncEvaluationDataFactory(splits_storage, segment_storage, rbs_storage) + + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + split1 = from_raw(data["ff"]["d"][0]) + split2 = from_raw(data["ff"]["d"][1]) + split3 = from_raw(data["ff"]["d"][2]) + split4 = from_raw(data["ff"]["d"][3]) + await rbs_storage.update([rbs], [], 12) + await splits_storage.update([split1, split2, split3, split4], [], 12) + segment = segments.from_raw({'name': 'segment-test', 'added': ['pato@split.io'], 'removed': [], 'till': 123}) + await segment_storage.put(segment) + + ctx = await evaluation_facctory.context_for('bilal@split.io', ['test_prereq']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'test_prereq', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on" + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'test_prereq', {}, ctx)['treatment'] == "def_treatment" + + ctx = await evaluation_facctory.context_for('mauro@split.io', ['test_prereq']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'test_prereq', {'email': 'mauro@split.io'}, ctx)['treatment'] == "def_treatment" + + ctx = await evaluation_facctory.context_for('pato@split.io', ['test_prereq']) + assert e.eval_with_context('pato@split.io', 'pato@split.io', 'test_prereq', {'email': 'pato@split.io'}, ctx)['treatment'] == "def_treatment" + + ctx = await evaluation_facctory.context_for('nico@split.io', ['test_prereq']) + assert e.eval_with_context('nico@split.io', 'nico@split.io', 'test_prereq', {'email': 'nico@split.io'}, ctx)['treatment'] == "on" + + ctx = await evaluation_facctory.context_for('bilal@split.io', ['prereq_chain']) + assert e.eval_with_context('bilal@split.io', 'bilal@split.io', 'prereq_chain', {'email': 'bilal@split.io'}, ctx)['treatment'] == "on_whitelist" + + ctx = await evaluation_facctory.context_for('nico@split.io', ['prereq_chain']) + assert e.eval_with_context('nico@split.io', 'nico@split.io', 'test_prereq', {'email': 'nico@split.io'}, ctx)['treatment'] == "on" + + ctx = await evaluation_facctory.context_for('pato@split.io', ['prereq_chain']) + assert e.eval_with_context('pato@split.io', 'pato@split.io', 'prereq_chain', {'email': 'pato@split.io'}, ctx)['treatment'] == "on_default" + + ctx = await evaluation_facctory.context_for('mauro@split.io', ['prereq_chain']) + assert e.eval_with_context('mauro@split.io', 'mauro@split.io', 'prereq_chain', {'email': 'mauro@split.io'}, ctx)['treatment'] == "on_default" + +class EvaluationDataFactoryTests(object): + """Test evaluation factory class.""" + + def test_get_context(self): + """Test context.""" + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, [Prerequisites('split2', ['on'])]) + split2 = Split('split2', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, []) + events_queue = queue.Queue() + flag_storage = InMemorySplitStorage(events_queue, []) + segment_storage = InMemorySegmentStorage(events_queue) + rbs_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + flag_storage.update([mocked_split, split2], [], -1) + rbs = copy.deepcopy(rbs_raw) + rbs['conditions'].append( + {"matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": None + } + ] + }, + }) + rbs = rule_based_segments.from_raw(rbs) + rbs_segment_storage.update([rbs], [], -1) + + eval_factory = EvaluationDataFactory(flag_storage, segment_storage, rbs_segment_storage) + ec = eval_factory.context_for('bilal@split.io', ['some']) + assert ec.rbs_segments == {'sample_rule_based_segment': rbs} + assert ec.segment_memberships == {"employees": False} + assert ec.flags.get("split2").name == "split2" + + segment_storage.update("employees", {"mauro@split.io"}, {}, 1234) + ec = eval_factory.context_for('mauro@split.io', ['some']) + assert ec.rbs_segments == {'sample_rule_based_segment': rbs} + assert ec.segment_memberships == {"employees": True} + +class EvaluationDataFactoryAsyncTests(object): + """Test evaluation factory class.""" + + @pytest.mark.asyncio + async def test_get_context(self): + """Test context.""" + mocked_split = Split('some', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, [Prerequisites('split2', ['on'])]) + split2 = Split('split2', 12345, False, 'off', 'user', Status.ACTIVE, 12, split_conditions, 1.2, 100, 1234, {}, None, False, []) + internal_events_queue = asyncio.Queue() + flag_storage = InMemorySplitStorageAsync(internal_events_queue, []) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rbs_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + await flag_storage.update([mocked_split, split2], [], -1) + rbs = copy.deepcopy(rbs_raw) + rbs['conditions'].append( + {"matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": None + } + ] + }, + }) + rbs = rule_based_segments.from_raw(rbs) + await rbs_segment_storage.update([rbs], [], -1) + + eval_factory = AsyncEvaluationDataFactory(flag_storage, segment_storage, rbs_segment_storage) + ec = await eval_factory.context_for('bilal@split.io', ['some']) + assert ec.rbs_segments == {'sample_rule_based_segment': rbs} + assert ec.segment_memberships == {"employees": False} + assert ec.flags.get("split2").name == "split2" + + await segment_storage.update("employees", {"mauro@split.io"}, {}, 1234) + ec = await eval_factory.context_for('mauro@split.io', ['some']) + assert ec.rbs_segments == {'sample_rule_based_segment': rbs} + assert ec.segment_memberships == {"employees": True} diff --git a/tests/engine/test_impressions.py b/tests/engine/test_impressions.py index c1d43468..715bfe1b 100644 --- a/tests/engine/test_impressions.py +++ b/tests/engine/test_impressions.py @@ -1,10 +1,15 @@ """Impression manager, observer & hasher tests.""" from datetime import datetime -from splitio.engine.impressions import Hasher, Observer, Counter, Manager, \ - ImpressionsMode, truncate_time -from splitio.models.impressions import Impression +import unittest.mock as mock +import pytest +from splitio.engine.impressions.impressions import Manager, ImpressionsMode +from splitio.engine.impressions.manager import Hasher, Observer, Counter, truncate_time +from splitio.engine.impressions.strategies import StrategyDebugMode, StrategyOptimizedMode, StrategyNoneMode +from splitio.models.impressions import Impression, ImpressionDecorated from splitio.client.listener import ImpressionListenerWrapper - +import splitio.models.telemetry as ModelTelemetry +from splitio.engine.telemetry import TelemetryStorageProducer +from splitio.storage.inmemmory import InMemoryTelemetryStorage def utctime_ms_reimplement(): """Re-implementation of utctime_ms to avoid conflicts with mock/patching.""" @@ -18,16 +23,16 @@ def test_changes_are_reflected(self): """Test that change in any field changes the resulting hash.""" total = set() hasher = Hasher() - total.add(hasher.process(Impression('key1', 'feature1', 'on', 'killed', 123, None, 456))) - total.add(hasher.process(Impression('key2', 'feature1', 'on', 'killed', 123, None, 456))) - total.add(hasher.process(Impression('key1', 'feature2', 'on', 'killed', 123, None, 456))) - total.add(hasher.process(Impression('key1', 'feature1', 'off', 'killed', 123, None, 456))) - total.add(hasher.process(Impression('key1', 'feature1', 'on', 'not killed', 123, None, 456))) - total.add(hasher.process(Impression('key1', 'feature1', 'on', 'killed', 321, None, 456))) + total.add(hasher.process(Impression('key1', 'feature1', 'on', 'killed', 123, None, 456, None, {}))) + total.add(hasher.process(Impression('key2', 'feature1', 'on', 'killed', 123, None, 456, None, {}))) + total.add(hasher.process(Impression('key1', 'feature2', 'on', 'killed', 123, None, 456, None, {}))) + total.add(hasher.process(Impression('key1', 'feature1', 'off', 'killed', 123, None, 456, None, {}))) + total.add(hasher.process(Impression('key1', 'feature1', 'on', 'not killed', 123, None, 456, None, {}))) + total.add(hasher.process(Impression('key1', 'feature1', 'on', 'killed', 321, None, 456, None, {}))) assert len(total) == 6 # Re-adding the first-one should not increase the number of different hashes - total.add(hasher.process(Impression('key1', 'feature1', 'on', 'killed', 123, None, 456))) + total.add(hasher.process(Impression('key1', 'feature1', 'on', 'killed', 123, None, 456, None, {}))) assert len(total) == 6 @@ -37,26 +42,26 @@ class ImpressionObserverTests(object): def test_previous_time_properly_calculated(self): """Test that the previous time is properly set.""" observer = Observer(5) - assert (observer.test_and_set(Impression('key1', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key1', 'f1', 'on', 'killed', 123, None, 456)) - assert (observer.test_and_set(Impression('key1', 'f1', 'on', 'killed', 123, None, 457)) - == Impression('key1', 'f1', 'on', 'killed', 123, None, 457, 456)) + assert (observer.test_and_set(Impression('key1', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key1', 'f1', 'on', 'killed', 123, None, 456, None, None)) + assert (observer.test_and_set(Impression('key1', 'f1', 'on', 'killed', 123, None, 457, None, None)) + == Impression('key1', 'f1', 'on', 'killed', 123, None, 457, 456, None)) # Add 5 new impressions to evict the first one and check that previous time is None again - assert (observer.test_and_set(Impression('key2', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key2', 'f1', 'on', 'killed', 123, None, 456)) - assert (observer.test_and_set(Impression('key3', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key3', 'f1', 'on', 'killed', 123, None, 456)) - assert (observer.test_and_set(Impression('key4', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key4', 'f1', 'on', 'killed', 123, None, 456)) - assert (observer.test_and_set(Impression('key5', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key5', 'f1', 'on', 'killed', 123, None, 456)) - assert (observer.test_and_set(Impression('key6', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key6', 'f1', 'on', 'killed', 123, None, 456)) + assert (observer.test_and_set(Impression('key2', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key2', 'f1', 'on', 'killed', 123, None, 456, None, None)) + assert (observer.test_and_set(Impression('key3', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key3', 'f1', 'on', 'killed', 123, None, 456, None, None)) + assert (observer.test_and_set(Impression('key4', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key4', 'f1', 'on', 'killed', 123, None, 456, None, None)) + assert (observer.test_and_set(Impression('key5', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key5', 'f1', 'on', 'killed', 123, None, 456, None, None)) + assert (observer.test_and_set(Impression('key6', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key6', 'f1', 'on', 'killed', 123, None, 456, None, None)) # Re-process the first-one - assert (observer.test_and_set(Impression('key1', 'f1', 'on', 'killed', 123, None, 456)) - == Impression('key1', 'f1', 'on', 'killed', 123, None, 456)) + assert (observer.test_and_set(Impression('key1', 'f1', 'on', 'killed', 123, None, 456, None, None)) + == Impression('key1', 'f1', 'on', 'killed', 123, None, 456, None, None)) class ImpressionCounterTests(object): @@ -67,15 +72,15 @@ def test_tracking_and_popping(self): counter = Counter() utc_now = utctime_ms_reimplement() utc_1_hour_after = utc_now + (3600 * 1000) - counter.track([Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now), - Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now), - Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now)]) + counter.track([Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now, None, None), + Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now, None, None), + Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now, None, None)]) - counter.track([Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now)]) + counter.track([Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now, None, None)]) - counter.track([Impression('k1', 'f1', 'on', 'l1', 123, None, utc_1_hour_after), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_1_hour_after)]) + counter.track([Impression('k1', 'f1', 'on', 'l1', 123, None, utc_1_hour_after, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_1_hour_after, None, None)]) assert set(counter.pop_all()) == set([ Counter.CountPerFeature('f1', truncate_time(utc_now), 3), @@ -85,7 +90,6 @@ def test_tracking_and_popping(self): assert len(counter._data) == 0 assert set(counter.pop_all()) == set() - class ImpressionManagerTests(object): """Test impressions manager in all of its configurations.""" @@ -96,434 +100,570 @@ def test_standalone_optimized(self, mocker): utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() - manager = Manager() # no listener - assert manager._counter is not None - assert manager._observer is not None - assert manager._listener is None + manager = Manager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + assert manager._strategy._observer is not None + assert isinstance(manager._strategy, StrategyOptimizedMode) + assert isinstance(manager._none_strategy, StrategyNoneMode) # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] + + assert for_unique_keys_tracker == [] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert deduped == 0 # Tracking the same impression a ms later should be empty - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) assert imps == [] + assert deduped == 1 + assert for_unique_keys_tracker == [] # Tracking an impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] + assert deduped == 0 # Advance the perceived clock one hour old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1)] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] + assert deduped == 0 + assert for_unique_keys_tracker == [] - assert len(manager._observer._cache._data) == 3 # distinct impressions seen - assert len(manager._counter._data) == 3 # 2 distinct features. 1 seen in 2 different timeframes + assert len(manager._strategy._observer._cache._data) == 3 # distinct impressions seen + assert for_counter == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] - assert set(manager._counter.pop_all()) == set([ - Counter.CountPerFeature('f1', truncate_time(old_utc), 3), - Counter.CountPerFeature('f2', truncate_time(old_utc), 1), - Counter.CountPerFeature('f1', truncate_time(utc_now), 2) + # Test counting only from the second impression + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k3', 'f3', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) + assert for_counter == [] + assert deduped == 0 + assert for_unique_keys_tracker == [] + + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k3', 'f3', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) + ]) + assert for_counter == [Impression('k3', 'f3', 'on', 'l1', 123, None, utc_now-1, utc_now-1, None)] + assert deduped == 1 + assert for_unique_keys_tracker == [] def test_standalone_debug(self, mocker): - """Test impressions manager in optimized mode with sdk in standalone mode.""" + """Test impressions manager in debug mode with sdk in standalone mode.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) - manager = Manager(ImpressionsMode.DEBUG) # no listener - assert manager._counter is None - assert manager._observer is not None - assert manager._listener is None + manager = Manager(StrategyDebugMode(), StrategyNoneMode(), mocker.Mock()) # no listener + assert manager._strategy._observer is not None + assert isinstance(manager._strategy, StrategyDebugMode) # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert for_counter == [] + assert for_unique_keys_tracker == [] # Tracking the same impression a ms later should return the impression - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, utc_now-3)] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, utc_now-3, None)] + assert for_counter == [] + assert for_unique_keys_tracker == [] # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] + assert for_counter == [] + assert for_unique_keys_tracker == [] # Advance the perceived clock one hour old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1)] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] + assert for_counter == [] + assert for_unique_keys_tracker == [] - assert len(manager._observer._cache._data) == 3 # distinct impressions seen + assert len(manager._strategy._observer._cache._data) == 3 # distinct impressions seen - def test_non_standalone_optimized(self, mocker): - """Test impressions manager in optimized mode with sdk in standalone mode.""" + def test_standalone_none(self, mocker): + """Test impressions manager in none mode with sdk in standalone mode.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) - manager = Manager(ImpressionsMode.OPTIMIZED, False) # no listener - assert manager._counter is None - assert manager._observer is None - assert manager._listener is None + manager = Manager(StrategyNoneMode(), StrategyNoneMode(), mocker.Mock()) # no listener + assert isinstance(manager._strategy, StrategyNoneMode) - # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + # no impressions are tracked, only counter and mtk + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] + assert imps == [] + assert for_counter == [ + Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None) + ] + assert for_unique_keys_tracker == [('k1', 'f1'), ('k1', 'f2')] - # Tracking the same impression a ms later should not be empty - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + # Tracking the same impression a ms later should not return the impression and no change on mtk cache + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2)] + assert imps == [] - # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + # Tracking an impression with a different key, will only increase mtk + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k3', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [] + assert for_unique_keys_tracker == [('k3', 'f1')] + assert for_counter == [ + Impression('k3', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None) + ] # Advance the perceived clock one hour + old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) - # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + # Track the same impressions but "one hour later", no changes on mtk + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2)] + assert imps == [] + assert for_counter == [ + Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None) + ] - def test_non_standalone_debug(self, mocker): + def test_standalone_optimized_listener(self, mocker): """Test impressions manager in optimized mode with sdk in standalone mode.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) +# mocker.patch('splitio.util.time.utctime_ms', return_value=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) - manager = Manager(ImpressionsMode.DEBUG, False) # no listener - assert manager._counter is None - assert manager._observer is None - assert manager._listener is None + manager = Manager(StrategyOptimizedMode(), StrategyNoneMode(), mocker.Mock()) + assert manager._strategy._observer is not None + assert isinstance(manager._strategy, StrategyOptimizedMode) # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert deduped == 0 + assert listen == [(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), None)] + assert for_unique_keys_tracker == [] - # Tracking the same impression a ms later should not be empty - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + # Tracking the same impression a ms later should return empty + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2)] + assert imps == [] + assert deduped == 1 + assert listen == [(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, utc_now-3, None), None)] + assert for_unique_keys_tracker == [] # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] + assert deduped == 0 + assert listen == [(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), None)] + assert for_unique_keys_tracker == [] # Advance the perceived clock one hour + old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) + ]) + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] + assert deduped == 0 + assert listen == [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), None), + (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None), None), + ] + assert for_unique_keys_tracker == [] + assert len(manager._strategy._observer._cache._data) == 3 # distinct impressions seen + assert for_counter == [ + Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None) + ] + + # Test counting only from the second impression + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k3', 'f3', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2)] + assert for_counter == [] + assert deduped == 0 + assert for_unique_keys_tracker == [] - def test_standalone_optimized_listener(self, mocker): + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k3', 'f3', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) + ]) + assert for_counter == [ + Impression('k3', 'f3', 'on', 'l1', 123, None, utc_now-1, utc_now-1, None) + ] + assert deduped == 1 + assert for_unique_keys_tracker == [] + + def test_standalone_debug_listener(self, mocker): """Test impressions manager in optimized mode with sdk in standalone mode.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + imps = [] listener = mocker.Mock(spec=ImpressionListenerWrapper) - - manager = Manager(listener=listener) # no listener - assert manager._counter is not None - assert manager._observer is not None - assert manager._listener is not None + manager = Manager(StrategyDebugMode(), StrategyNoneMode(), mocker.Mock()) + assert isinstance(manager._strategy, StrategyDebugMode) # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] - # Tracking the same impression a ms later should return empty - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + assert listen == [(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), None)] + + # Tracking the same impression a ms later should return the imp + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, utc_now-3, None)] + assert listen == [(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, utc_now-3, None), None)] + assert for_counter == [] + assert for_unique_keys_tracker == [] # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] + assert listen == [(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), None)] + assert for_counter == [] + assert for_unique_keys_tracker == [] # Advance the perceived clock one hour old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) - ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1)] - - assert len(manager._observer._cache._data) == 3 # distinct impressions seen - assert len(manager._counter._data) == 3 # 2 distinct features. 1 seen in 2 different timeframes - - assert set(manager._counter.pop_all()) == set([ - Counter.CountPerFeature('f1', truncate_time(old_utc), 3), - Counter.CountPerFeature('f2', truncate_time(old_utc), 1), - Counter.CountPerFeature('f1', truncate_time(utc_now), 2) - ]) - - assert listener.log_impression.mock_calls == [ - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f2', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-2, old_utc-3), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, old_utc-1), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) + ]) + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] + assert listen == [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3, None), None), + (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None), None) ] + assert len(manager._strategy._observer._cache._data) == 3 # distinct impressions seen + assert for_counter == [] + assert for_unique_keys_tracker == [] - def test_standalone_debug_listener(self, mocker): - """Test impressions manager in optimized mode with sdk in standalone mode.""" + def test_standalone_none_listener(self, mocker): + """Test impressions manager in none mode with sdk in standalone mode.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) - imps = [] - listener = mocker.Mock(spec=ImpressionListenerWrapper) - manager = Manager(ImpressionsMode.DEBUG, listener=listener) - assert manager._counter is None - assert manager._observer is not None - assert manager._listener is not None + manager = Manager(StrategyNoneMode(), StrategyNoneMode(), mocker.Mock()) + assert isinstance(manager._strategy, StrategyNoneMode) - # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + # An impression that hasn't happened in the last hour (pt = None) should not be tracked + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] + assert imps == [] + assert listen == [(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), None)] - # Tracking the same impression a ms later should return the imp - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + assert for_counter == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert for_unique_keys_tracker == [('k1', 'f1'), ('k1', 'f2')] + + # Tracking the same impression a ms later should return empty, no updates on mtk + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, utc_now-3)] + assert imps == [] + assert listen == [(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), None)] + assert for_counter == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2, None)] + assert for_unique_keys_tracker == [('k1', 'f1')] - # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + # Tracking a in impression with a different key update mtk + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [] + assert listen == [(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), None)] + assert for_counter == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] + assert for_unique_keys_tracker == [('k2', 'f1')] # Advance the perceived clock one hour old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) - ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1)] - - assert len(manager._observer._cache._data) == 3 # distinct impressions seen - - assert listener.log_impression.mock_calls == [ - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f2', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-2, old_utc-3), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, old_utc-1), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, old_utc-3), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) + ]) + assert imps == [] + assert for_counter == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None)] + assert listen == [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), None), + (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), None) ] + assert for_unique_keys_tracker == [('k1', 'f1'), ('k2', 'f1')] - def test_non_standalone_optimized_listener(self, mocker): + def test_impression_toggle_optimized(self, mocker): """Test impressions manager in optimized mode with sdk in standalone mode.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() - imps = [] - listener = mocker.Mock(spec=ImpressionListenerWrapper) - manager = Manager(ImpressionsMode.OPTIMIZED, False, listener) # no listener - assert manager._counter is None - assert manager._observer is None - assert manager._listener is not None + manager = Manager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + assert manager._strategy._observer is not None + assert isinstance(manager._strategy, StrategyOptimizedMode) + assert isinstance(manager._none_strategy, StrategyNoneMode) # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), True), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] - # Tracking the same impression a ms later should return the imp - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + assert for_unique_keys_tracker == [('k1', 'f1')] + assert imps == [Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert deduped == 1 + + def test_impression_toggle_debug(self, mocker): + """Test impressions manager in optimized mode with sdk in standalone mode.""" + + # Mock utc_time function to be able to play with the clock + utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 + utc_time_mock = mocker.Mock() + utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + manager = Manager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + assert manager._strategy._observer is not None + + # An impression that hasn't happened in the last hour (pt = None) should be tracked + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), True), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2)] - # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + assert for_unique_keys_tracker == [('k1', 'f1')] + assert imps == [Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert deduped == 1 + + def test_impression_toggle_none(self, mocker): + """Test impressions manager in optimized mode with sdk in standalone mode.""" + + # Mock utc_time function to be able to play with the clock + utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 + utc_time_mock = mocker.Mock() + utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + strategy = StrategyNoneMode() + manager = Manager(strategy, strategy, telemetry_runtime_producer) # no listener + + # An impression that hasn't happened in the last hour (pt = None) should be tracked + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), True), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + + assert for_unique_keys_tracker == [('k1', 'f1'), ('k1', 'f2')] + assert imps == [] + assert deduped == 2 + + def test_impressions_properties_optimized(self, mocker): + """Test impressions manager in optimized mode with impressions properties.""" + + # Mock utc_time function to be able to play with the clock + utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 + utc_time_mock = mocker.Mock() + utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + manager = Manager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + assert manager._strategy._observer is not None + assert isinstance(manager._strategy, StrategyOptimizedMode) + assert isinstance(manager._none_strategy, StrategyNoneMode) + + # An impression that hasn't happened in the last hour (pt = None) should be tracked + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) + ]) + + assert for_unique_keys_tracker == [] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert deduped == 0 + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) + ]) + assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] # Advance the perceived clock one hour old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) - ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2)] - - assert listener.log_impression.mock_calls == [ - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f2', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-2), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, old_utc-1), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) - ] - - def test_non_standalone_debug_listener(self, mocker): - """Test impressions manager in optimized mode with sdk in standalone mode.""" + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, {'prop': 'value'}), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) + ]) + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, {'prop': 'value'}), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] + assert deduped == 0 + assert for_unique_keys_tracker == [] + + def test_impressions_properties_debug(self, mocker): + """Test impressions manager in optimized mode with impressions properties.""" # Mock utc_time function to be able to play with the clock utc_now = truncate_time(utctime_ms_reimplement()) + 1800 * 1000 utc_time_mock = mocker.Mock() utc_time_mock.return_value = utc_now - mocker.patch('splitio.util.utctime_ms', new=utc_time_mock) + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() - listener = mocker.Mock(spec=ImpressionListenerWrapper) - manager = Manager(ImpressionsMode.DEBUG, False, listener) # no listener - assert manager._counter is None - assert manager._observer is None - assert manager._listener is not None + manager = Manager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener # An impression that hasn't happened in the last hour (pt = None) should be tracked - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), None), - (Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3), None) - ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3), - Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3)] - - # Tracking the same impression a ms later should return the imp - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2), None) + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), False), None), + (ImpressionDecorated(Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None), False), None) ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-2)] - # Tracking a in impression with a different key makes it to the queue - imps = manager.process_impressions([ - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1), None) + assert for_unique_keys_tracker == [] + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-3, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, utc_now-3, None, None)] + assert deduped == 0 + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None), False), None) ]) - assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1)] + assert imps == [Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-1, None, None)] # Advance the perceived clock one hour old_utc = utc_now # save it to compare captured impressions utc_now += 3600 * 1000 utc_time_mock.return_value = utc_now + mocker.patch('splitio.engine.impressions.strategies.utctime_ms', return_value=utc_time_mock()) # Track the same impressions but "one hour later" - imps = manager.process_impressions([ - (Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - (Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) - ]) - assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), - Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2)] - - assert listener.log_impression.mock_calls == [ - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f2', 'on', 'l1', 123, None, old_utc-3), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, old_utc-2), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, old_utc-1), None), - mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1), None), - mocker.call(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2), None) - ] + imps, deduped, listen, for_counter, for_unique_keys_tracker = manager.process_impressions([ + (ImpressionDecorated(Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, {'prop': 'value'}), False), None), + (ImpressionDecorated(Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, None, None), False), None) + ]) + assert imps == [Impression('k1', 'f1', 'on', 'l1', 123, None, utc_now-1, None, {'prop': 'value'}), + Impression('k2', 'f1', 'on', 'l1', 123, None, utc_now-2, old_utc-1, None)] + assert deduped == 0 + assert for_unique_keys_tracker == [] \ No newline at end of file diff --git a/tests/engine/test_send_adapters.py b/tests/engine/test_send_adapters.py new file mode 100644 index 00000000..97a17531 --- /dev/null +++ b/tests/engine/test_send_adapters.py @@ -0,0 +1,289 @@ +import unittest.mock as mock +import ast +import json +import pytest +import redis.asyncio as aioredis + +from splitio.engine.impressions.adapters import InMemorySenderAdapter, RedisSenderAdapter, PluggableSenderAdapter, \ + InMemorySenderAdapterAsync, RedisSenderAdapterAsync, PluggableSenderAdapterAsync +from splitio.engine.impressions import adapters +from splitio.api.telemetry import TelemetryAPI, TelemetryAPIAsync +from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterAsync +from splitio.engine.impressions.manager import Counter +from tests.storage.test_pluggable import StorageMockAdapter, StorageMockAdapterAsync + + +class InMemorySenderAdapterTests(object): + """In memory sender adapter test.""" + + def test_uniques_formatter(self, mocker): + """Test formatting dict to json.""" + + uniques = {"feature1": set({'key1', 'key2', 'key3'}), + "feature2": set({'key6', 'key1', 'key10'}), + } + formatted = [ + {'f': 'feature1', 'ks': ['key1', 'key2', 'key3']}, + {'f': 'feature2', 'ks': ['key1', 'key6', 'key10']}, + ] + + sender_adapter = InMemorySenderAdapter(mocker.Mock()) + for i in range(0,1): + assert(sorted(sender_adapter._uniques_formatter(uniques)[i]["ks"]) == sorted(formatted[i]["ks"])) + + + @mock.patch('splitio.api.telemetry.TelemetryAPI.record_unique_keys') + def test_record_unique_keys(self, mocker): + """Test sending unique keys.""" + + uniques = {"feature1": set({'key1', 'key2', 'key3'}), + "feature2": set({'key1', 'key2', 'key3'}), + } + telemetry_api = TelemetryAPI(mocker.Mock(), 'some_api_key', mocker.Mock(), mocker.Mock()) + sender_adapter = InMemorySenderAdapter(telemetry_api) + sender_adapter.record_unique_keys(uniques) + assert(mocker.called) + +class InMemorySenderAdapterAsyncTests(object): + """In memory sender adapter test.""" + + @pytest.mark.asyncio + async def test_record_unique_keys(self, mocker): + """Test sending unique keys.""" + + uniques = {"feature1": set({'key1', 'key2', 'key3'}), + "feature2": set({'key1', 'key2', 'key3'}), + } + telemetry_api = TelemetryAPIAsync(mocker.Mock(), 'some_api_key', mocker.Mock(), mocker.Mock()) + self.called = False + async def record_unique_keys(*args): + self.called = True + + telemetry_api.record_unique_keys = record_unique_keys + sender_adapter = InMemorySenderAdapterAsync(telemetry_api) + await sender_adapter.record_unique_keys(uniques) + assert(self.called) + +class RedisSenderAdapterTests(object): + """Redis sender adapter test.""" + + def test_uniques_formatter(self, mocker): + """Test formatting dict to json.""" + + uniques = {"feature1": set({'key1', 'key2', 'key3'}), + "feature2": set({'key6', 'key1', 'key10'}), + } + formatted = [ + {'f': 'feature1', 'ks': ['key1', 'key2', 'key3']}, + {'f': 'feature2', 'ks': ['key6', 'key1', 'key10']}, + ] + + for i in range(0,1): + assert(sorted(ast.literal_eval(adapters._uniques_formatter(uniques)[i])["ks"]) == sorted(formatted[i]["ks"])) + + @mock.patch('splitio.storage.adapters.redis.RedisAdapter.rpush') + def test_record_unique_keys(self, mocker): + """Test sending unique keys.""" + + uniques = {"feature1": set({'key1', 'key2', 'key3'}), + "feature2": set({'key1', 'key2', 'key3'}), + } + redis_client = RedisAdapter(mocker.Mock(), mocker.Mock()) + sender_adapter = RedisSenderAdapter(redis_client) + sender_adapter.record_unique_keys(uniques) + assert(mocker.called) + + mocker.reset_mock() + sender_adapter.record_unique_keys({}) + assert(not mocker.called) + + @mock.patch('splitio.storage.adapters.redis.RedisPipelineAdapter.hincrby') + def test_flush_counters(self, mocker): + """Test sending counters.""" + + counters = [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + ] + redis_client = RedisAdapter(mocker.Mock(), mocker.Mock()) + sender_adapter = RedisSenderAdapter(redis_client) + sender_adapter.flush_counters(counters) + assert(mocker.called) + + mocker.reset_mock() + sender_adapter.flush_counters({}) + assert(not mocker.called) + + @mock.patch('splitio.storage.adapters.redis.RedisAdapter.expire') + def test_expire_keys(self, mocker): + """Test set expire key.""" + + total_keys = 100 + inserted = 10 + redis_client = RedisAdapter(mocker.Mock(), mocker.Mock()) + sender_adapter = RedisSenderAdapter(redis_client) + sender_adapter._expire_keys(mocker.Mock(), mocker.Mock(), total_keys, inserted) + assert(not mocker.called) + + total_keys = 100 + inserted = 100 + sender_adapter._expire_keys(mocker.Mock(), mocker.Mock(), total_keys, inserted) + assert(mocker.called) + + +class RedisSenderAdapterAsyncTests(object): + """Redis sender adapter test.""" + + @pytest.mark.asyncio + async def test_record_unique_keys(self, mocker): + """Test sending unique keys.""" + + uniques = {"feature1": set({'key1', 'key2', 'key3'}), + "feature2": set({'key1', 'key2', 'key3'}), + } + redis_client = RedisAdapterAsync(mocker.Mock(), mocker.Mock()) + sender_adapter = RedisSenderAdapterAsync(redis_client) + + self.called = False + async def rpush(*args): + self.called = True + + redis_client.rpush = rpush + await sender_adapter.record_unique_keys(uniques) + assert(self.called) + + @pytest.mark.asyncio + async def test_flush_counters(self, mocker): + """Test sending counters.""" + + counters = [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + ] + redis_client = await aioredis.from_url("redis://localhost") + sender_adapter = RedisSenderAdapterAsync(redis_client) + self.called = False + def hincrby(*args): + self.called = True + self.called2 = False + async def execute(*args): + self.called2 = True + return [1] + + with mock.patch('redis.asyncio.client.Pipeline.hincrby', hincrby): + with mock.patch('redis.asyncio.client.Pipeline.execute', execute): + await sender_adapter.flush_counters(counters) + assert(self.called) + assert(self.called2) + + @pytest.mark.asyncio + async def test_expire_keys(self, mocker): + """Test set expire key.""" + + total_keys = 100 + inserted = 10 + redis_client = RedisAdapterAsync(mocker.Mock(), mocker.Mock()) + sender_adapter = RedisSenderAdapterAsync(redis_client) + self.called = False + async def expire(*args): + self.called = True + redis_client.expire = expire + + await sender_adapter._expire_keys(mocker.Mock(), mocker.Mock(), total_keys, inserted) + assert(not self.called) + + total_keys = 100 + inserted = 100 + await sender_adapter._expire_keys(mocker.Mock(), mocker.Mock(), total_keys, inserted) + assert(self.called) + + +class PluggableSenderAdapterTests(object): + """Pluggable sender adapter test.""" + + def test_record_unique_keys(self, mocker): + """Test sending unique keys.""" + adapter = StorageMockAdapter() + sender_adapter = PluggableSenderAdapter(adapter) + + uniques = {"feature1": set({"key1", "key2", "key3"}), + "feature2": set({"key1", "key6", "key10"}), + } + formatted = [ + '{"f": "feature1", "ks": ["key3", "key2", "key1"]}', + '{"f": "feature2", "ks": ["key1", "key10", "key6"]}', + ] + + sender_adapter.record_unique_keys(uniques) + assert(sorted(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][0])["ks"]) == sorted(json.loads(formatted[0])["ks"])) + assert(sorted(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][1])["ks"]) == sorted(json.loads(formatted[1])["ks"])) + assert(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][0])["f"] == "feature1") + assert(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][1])["f"] == "feature2") + assert(adapter._expire[adapters._MTK_QUEUE_KEY] == adapters._MTK_KEY_DEFAULT_TTL) + sender_adapter.record_unique_keys(uniques) + assert(adapter._expire[adapters._MTK_QUEUE_KEY] != -1) + + adapter._keys[adapters._MTK_QUEUE_KEY] = {} + sender_adapter.record_unique_keys({}) + assert(adapter._keys[adapters._MTK_QUEUE_KEY] == {}) + + def test_flush_counters(self, mocker): + """Test sending counters.""" + adapter = StorageMockAdapter() + sender_adapter = PluggableSenderAdapter(adapter) + + counters = [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + ] + + sender_adapter.flush_counters(counters) + assert(adapter._keys[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f1::123'] == 2) + assert(adapter._keys[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f2::123'] == 123) + assert(adapter._expire[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f1::123'] == adapters._IMP_COUNT_KEY_DEFAULT_TTL) + sender_adapter.flush_counters(counters) + assert(adapter._expire[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f2::123'] == adapters._IMP_COUNT_KEY_DEFAULT_TTL) + +class PluggableSenderAdapterAsyncTests(object): + """Pluggable sender adapter test.""" + + @pytest.mark.asyncio + async def test_record_unique_keys(self, mocker): + """Test sending unique keys.""" + adapter = StorageMockAdapterAsync() + sender_adapter = PluggableSenderAdapterAsync(adapter) + + uniques = {"feature1": set({"key1", "key2", "key3"}), + "feature2": set({"key1", "key6", "key10"}), + } + formatted = [ + '{"f": "feature1", "ks": ["key3", "key2", "key1"]}', + '{"f": "feature2", "ks": ["key1", "key10", "key6"]}', + ] + + await sender_adapter.record_unique_keys(uniques) + assert(sorted(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][0])["ks"]) == sorted(json.loads(formatted[0])["ks"])) + assert(sorted(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][1])["ks"]) == sorted(json.loads(formatted[1])["ks"])) + assert(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][0])["f"] == "feature1") + assert(json.loads(adapter._keys[adapters._MTK_QUEUE_KEY][1])["f"] == "feature2") + assert(adapter._expire[adapters._MTK_QUEUE_KEY] == adapters._MTK_KEY_DEFAULT_TTL) + await sender_adapter.record_unique_keys(uniques) + assert(adapter._expire[adapters._MTK_QUEUE_KEY] != -1) + + @pytest.mark.asyncio + async def test_flush_counters(self, mocker): + """Test sending counters.""" + adapter = StorageMockAdapterAsync() + sender_adapter = PluggableSenderAdapterAsync(adapter) + + counters = [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + ] + + await sender_adapter.flush_counters(counters) + assert(adapter._keys[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f1::123'] == 2) + assert(adapter._keys[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f2::123'] == 123) + assert(adapter._expire[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f1::123'] == adapters._IMP_COUNT_KEY_DEFAULT_TTL) + await sender_adapter.flush_counters(counters) + assert(adapter._expire[adapters._IMP_COUNT_QUEUE_KEY + "." + 'f2::123'] == adapters._IMP_COUNT_KEY_DEFAULT_TTL) \ No newline at end of file diff --git a/tests/engine/test_telemetry.py b/tests/engine/test_telemetry.py new file mode 100644 index 00000000..f4b669ea --- /dev/null +++ b/tests/engine/test_telemetry.py @@ -0,0 +1,797 @@ +import unittest.mock as mock +import pytest + +from splitio.engine.telemetry import TelemetryEvaluationConsumer, TelemetryEvaluationProducer, TelemetryInitConsumer, \ + TelemetryInitProducer, TelemetryRuntimeConsumer, TelemetryRuntimeProducer, TelemetryStorageConsumer, TelemetryStorageProducer, \ + TelemetryEvaluationConsumerAsync, TelemetryEvaluationProducerAsync, TelemetryInitConsumerAsync, \ + TelemetryInitProducerAsync, TelemetryRuntimeConsumerAsync, TelemetryRuntimeProducerAsync, TelemetryStorageConsumerAsync, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync + +class TelemetryStorageProducerTests(object): + """TelemetryStorageProducer test.""" + + def test_instances(self): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + + assert(isinstance(telemetry_producer._telemetry_evaluation_producer, TelemetryEvaluationProducer)) + assert(isinstance(telemetry_producer._telemetry_init_producer, TelemetryInitProducer)) + assert(isinstance(telemetry_producer._telemetry_runtime_producer, TelemetryRuntimeProducer)) + + assert(telemetry_producer._telemetry_evaluation_producer == telemetry_producer.get_telemetry_evaluation_producer()) + assert(telemetry_producer._telemetry_init_producer == telemetry_producer.get_telemetry_init_producer()) + assert(telemetry_producer._telemetry_runtime_producer == telemetry_producer.get_telemetry_runtime_producer()) + + def test_record_config(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_init_producer = TelemetryInitProducer(telemetry_storage) + config = {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG', + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None + } + telemetry_init_producer.record_config(config, {}, 5, 2) + telemetry_init_producer.record_active_and_redundant_factories(1, 0) + + assert(telemetry_storage._tel_config.get_stats() == {'oM': 0, + 'sT': telemetry_storage._tel_config._get_storage_type(config['operationMode'], config['storageType']), + 'sE': config['streamingEnabled'], + 'rR': {'sp': 30, 'se': 30, 'im': 60, 'ev': 60, 'te': 10}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': config['impressionsQueueSize'], + 'eQ': config['eventsQueueSize'], + 'iM': telemetry_storage._tel_config._get_impressions_mode(config['impressionsMode']), + 'iL': True if config['impressionListener'] is not None else False, + 'hp': telemetry_storage._tel_config._check_if_proxy_detected(), + 'bT': 0, + 'tR': 0, + 'nR': 0, + 'aF': 1, + 'rF': 0, + 'fsT': 5, + 'fsI': 2} + ) + + def test_record_ready_time(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_init_producer = TelemetryInitProducer(telemetry_storage) + + def record_ready_time(*args, **kwargs): + self.passed_arg = args[0] + + telemetry_storage.record_ready_time.side_effect = record_ready_time + telemetry_init_producer.record_ready_time(10) + assert(self.passed_arg == 10) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.record_bur_time_out') + def test_record_bur_timeout(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_init_producer = TelemetryInitProducer(telemetry_storage) + telemetry_init_producer.record_bur_time_out() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.record_not_ready_usage') + def test_record_not_ready_usage(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_init_producer = TelemetryInitProducer(telemetry_storage) + telemetry_init_producer.record_not_ready_usage() + assert(mocker.called) + + def test_record_latency(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_evaluation_producer = TelemetryEvaluationProducer(telemetry_storage) + + def record_latency(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_latency.side_effect = record_latency + telemetry_evaluation_producer.record_latency('method', 10) + assert(self.passed_args[0] == 'method') + assert(self.passed_args[1] == 10) + + def test_record_exception(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_evaluation_producer = TelemetryEvaluationProducer(telemetry_storage) + + def record_exception(*args, **kwargs): + self.passed_method = args[0] + + telemetry_storage.record_exception.side_effect = record_exception + telemetry_evaluation_producer.record_exception('method') + assert(self.passed_method == 'method') + + def test_add_tag(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def add_tag(*args, **kwargs): + self.passed_tag = args[0] + + telemetry_storage.add_tag.side_effect = add_tag + telemetry_runtime_producer.add_tag('tag') + assert(self.passed_tag == 'tag') + + def test_record_impression_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_impression_stats(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_impression_stats.side_effect = record_impression_stats + telemetry_runtime_producer.record_impression_stats('imp', 10) + assert(self.passed_args[0] == 'imp') + assert(self.passed_args[1] == 10) + + def test_record_event_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_event_stats(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_event_stats.side_effect = record_event_stats + telemetry_runtime_producer.record_event_stats('ev', 20) + assert(self.passed_args[0] == 'ev') + assert(self.passed_args[1] == 20) + + def test_record_successful_sync(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_successful_sync(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_successful_sync.side_effect = record_successful_sync + telemetry_runtime_producer.record_successful_sync('split', 50) + assert(self.passed_args[0] == 'split') + assert(self.passed_args[1] == 50) + + def test_record_sync_error(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_sync_error(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_sync_error.side_effect = record_sync_error + telemetry_runtime_producer.record_sync_error('segment', {'500': 1}) + assert(self.passed_args[0] == 'segment') + assert(self.passed_args[1] == {'500': 1}) + + def test_record_sync_latency(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_sync_latency(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_sync_latency.side_effect = record_sync_latency + telemetry_runtime_producer.record_sync_latency('t', 40) + assert(self.passed_args[0] == 't') + assert(self.passed_args[1] == 40) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.record_auth_rejections') + def test_record_auth_rejections(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + telemetry_runtime_producer.record_auth_rejections() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.record_token_refreshes') + def test_record_token_refreshes(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + telemetry_runtime_producer.record_token_refreshes() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.record_update_from_sse') + def test_record_update_from_sse(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + telemetry_runtime_producer.record_update_from_sse('sp') + assert(mocker.called) + + def test_record_streaming_event(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_streaming_event(*args, **kwargs): + self.passed_event = args[0] + + telemetry_storage.record_streaming_event.side_effect = record_streaming_event + telemetry_runtime_producer.record_streaming_event({'t', 40}) + assert(self.passed_event == {'t', 40}) + + def test_record_session_length(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducer(telemetry_storage) + + def record_session_length(*args, **kwargs): + self.passed_session = args[0] + + telemetry_storage.record_session_length.side_effect = record_session_length + telemetry_runtime_producer.record_session_length(30) + assert(self.passed_session == 30) + + +class TelemetryStorageProducerAsyncTests(object): + """TelemetryStorageProducer async test.""" + + @pytest.mark.asyncio + async def test_instances(self): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + + assert(isinstance(telemetry_producer._telemetry_evaluation_producer, TelemetryEvaluationProducerAsync)) + assert(isinstance(telemetry_producer._telemetry_init_producer, TelemetryInitProducerAsync)) + assert(isinstance(telemetry_producer._telemetry_runtime_producer, TelemetryRuntimeProducerAsync)) + + assert(telemetry_producer._telemetry_evaluation_producer == telemetry_producer.get_telemetry_evaluation_producer()) + assert(telemetry_producer._telemetry_init_producer == telemetry_producer.get_telemetry_init_producer()) + assert(telemetry_producer._telemetry_runtime_producer == telemetry_producer.get_telemetry_runtime_producer()) + + @pytest.mark.asyncio + async def test_record_config(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_init_producer = TelemetryInitProducerAsync(telemetry_storage) + + async def record_config(*args, **kwargs): + self.passed_config = args[0] + + telemetry_storage.record_config.side_effect = record_config + await telemetry_init_producer.record_config({'bT':0, 'nR':0, 'uC': 0}, {}) + assert(self.passed_config == {'bT':0, 'nR':0, 'uC': 0}) + + @pytest.mark.asyncio + async def test_record_ready_time(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_init_producer = TelemetryInitProducerAsync(telemetry_storage) + + async def record_ready_time(*args, **kwargs): + self.passed_arg = args[0] + + telemetry_storage.record_ready_time.side_effect = record_ready_time + await telemetry_init_producer.record_ready_time(10) + assert(self.passed_arg == 10) + + @pytest.mark.asyncio + async def test_record_bur_timeout(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def record_bur_time_out(*args): + self.called = True + telemetry_storage.record_bur_time_out = record_bur_time_out + + telemetry_init_producer = TelemetryInitProducerAsync(telemetry_storage) + await telemetry_init_producer.record_bur_time_out() + assert(self.called) + + @pytest.mark.asyncio + async def test_record_not_ready_usage(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def record_not_ready_usage(*args): + self.called = True + telemetry_storage.record_not_ready_usage = record_not_ready_usage + + telemetry_init_producer = TelemetryInitProducerAsync(telemetry_storage) + await telemetry_init_producer.record_not_ready_usage() + assert(self.called) + + @pytest.mark.asyncio + async def test_record_latency(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_evaluation_producer = TelemetryEvaluationProducerAsync(telemetry_storage) + + async def record_latency(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_latency.side_effect = record_latency + await telemetry_evaluation_producer.record_latency('method', 10) + assert(self.passed_args[0] == 'method') + assert(self.passed_args[1] == 10) + + @pytest.mark.asyncio + async def test_record_exception(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_evaluation_producer = TelemetryEvaluationProducerAsync(telemetry_storage) + + async def record_exception(*args, **kwargs): + self.passed_method = args[0] + + telemetry_storage.record_exception.side_effect = record_exception + await telemetry_evaluation_producer.record_exception('method') + assert(self.passed_method == 'method') + + @pytest.mark.asyncio + async def test_add_tag(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def add_tag(*args, **kwargs): + self.passed_tag = args[0] + + telemetry_storage.add_tag.side_effect = add_tag + await telemetry_runtime_producer.add_tag('tag') + assert(self.passed_tag == 'tag') + + @pytest.mark.asyncio + async def test_record_impression_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_impression_stats(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_impression_stats.side_effect = record_impression_stats + await telemetry_runtime_producer.record_impression_stats('imp', 10) + assert(self.passed_args[0] == 'imp') + assert(self.passed_args[1] == 10) + + @pytest.mark.asyncio + async def test_record_event_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_event_stats(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_event_stats.side_effect = record_event_stats + await telemetry_runtime_producer.record_event_stats('ev', 20) + assert(self.passed_args[0] == 'ev') + assert(self.passed_args[1] == 20) + + @pytest.mark.asyncio + async def test_record_successful_sync(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_successful_sync(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_successful_sync.side_effect = record_successful_sync + await telemetry_runtime_producer.record_successful_sync('split', 50) + assert(self.passed_args[0] == 'split') + assert(self.passed_args[1] == 50) + + @pytest.mark.asyncio + async def test_record_sync_error(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_sync_error(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_sync_error.side_effect = record_sync_error + await telemetry_runtime_producer.record_sync_error('segment', {'500': 1}) + assert(self.passed_args[0] == 'segment') + assert(self.passed_args[1] == {'500': 1}) + + @pytest.mark.asyncio + async def test_record_sync_latency(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_sync_latency(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_sync_latency.side_effect = record_sync_latency + await telemetry_runtime_producer.record_sync_latency('t', 40) + assert(self.passed_args[0] == 't') + assert(self.passed_args[1] == 40) + + @pytest.mark.asyncio + async def test_record_auth_rejections(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def record_auth_rejections(*args): + self.called = True + telemetry_storage.record_auth_rejections = record_auth_rejections + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + await telemetry_runtime_producer.record_auth_rejections() + assert(self.called) + + @pytest.mark.asyncio + async def test_record_token_refreshes(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def record_token_refreshes(*args): + self.called = True + telemetry_storage.record_token_refreshes = record_token_refreshes + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + await telemetry_runtime_producer.record_token_refreshes() + assert(self.called) + + @pytest.mark.asyncio + async def test_record_update_from_sse(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def record_update_from_sse(*args): + self.called = True + telemetry_storage.record_update_from_sse = record_update_from_sse + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + await telemetry_runtime_producer.record_update_from_sse('sp') + assert(self.called) + + @pytest.mark.asyncio + async def test_record_streaming_event(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_streaming_event(*args, **kwargs): + self.passed_event = args[0] + + telemetry_storage.record_streaming_event.side_effect = record_streaming_event + await telemetry_runtime_producer.record_streaming_event({'t', 40}) + assert(self.passed_event == {'t', 40}) + + @pytest.mark.asyncio + async def test_record_session_length(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_producer = TelemetryRuntimeProducerAsync(telemetry_storage) + + async def record_session_length(*args, **kwargs): + self.passed_session = args[0] + + telemetry_storage.record_session_length.side_effect = record_session_length + await telemetry_runtime_producer.record_session_length(30) + assert(self.passed_session == 30) + + +class TelemetryStorageConsumerTests(object): + """TelemetryStorageConsumer test.""" + + def test_instances(self): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_consumer = TelemetryStorageConsumer(telemetry_storage) + + assert(isinstance(telemetry_consumer._telemetry_evaluation_consumer, TelemetryEvaluationConsumer)) + assert(isinstance(telemetry_consumer._telemetry_init_consumer, TelemetryInitConsumer)) + assert(isinstance(telemetry_consumer._telemetry_runtime_consumer, TelemetryRuntimeConsumer)) + + assert(telemetry_consumer._telemetry_evaluation_consumer == telemetry_consumer.get_telemetry_evaluation_consumer()) + assert(telemetry_consumer._telemetry_init_consumer == telemetry_consumer.get_telemetry_init_consumer()) + assert(telemetry_consumer._telemetry_runtime_consumer == telemetry_consumer.get_telemetry_runtime_consumer()) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.get_bur_time_outs') + def test_get_bur_time_outs(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_init_consumer = TelemetryInitConsumer(telemetry_storage) + telemetry_init_consumer.get_bur_time_outs() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.get_not_ready_usage') + def get_not_ready_usage(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_init_consumer = TelemetryInitConsumer(telemetry_storage) + telemetry_init_consumer.get_not_ready_usage() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.get_config_stats') + def get_not_ready_usage(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_init_consumer = TelemetryInitConsumer(telemetry_storage) + telemetry_init_consumer.get_config_stats() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_exceptions') + def pop_exceptions(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_evaluation_consumer = TelemetryEvaluationConsumer(telemetry_storage) + telemetry_evaluation_consumer.pop_exceptions() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_latencies') + def pop_latencies(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_evaluation_consumer = TelemetryEvaluationConsumer(telemetry_storage) + telemetry_evaluation_consumer.pop_latencies() + assert(mocker.called) + + def test_get_impressions_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + + def get_impressions_stats(*args, **kwargs): + self.passed_type = args[0] + + telemetry_storage.get_impressions_stats.side_effect = get_impressions_stats + telemetry_runtime_consumer.get_impressions_stats('iQ') + assert(self.passed_type == 'iQ') + + def test_get_events_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + + def get_events_stats(*args, **kwargs): + self.event_type = args[0] + + telemetry_storage.get_events_stats.side_effect = get_events_stats + telemetry_runtime_consumer.get_events_stats('eQ') + assert(self.event_type == 'eQ') + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.get_last_synchronization') + def test_get_last_synchronization(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.get_last_synchronization() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_tags') + def test_pop_tags(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_tags() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_http_errors') + def test_pop_http_errors(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_http_errors() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_http_latencies') + def test_pop_http_latencies(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_http_latencies() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_auth_rejections') + def test_pop_auth_rejections(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_auth_rejections() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_update_from_sse') + def pop_update_from_sse(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_update_from_sse('sp') + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_update_from_sse') + def test_pop_auth_rejections(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_update_from_sse('sp') + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_token_refreshes') + def test_pop_token_refreshes(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_token_refreshes() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.pop_streaming_events') + def test_pop_streaming_events(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.pop_streaming_events() + assert(mocker.called) + + @mock.patch('splitio.storage.inmemmory.InMemoryTelemetryStorage.get_session_length') + def test_get_session_length(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_runtime_consumer = TelemetryRuntimeConsumer(telemetry_storage) + telemetry_runtime_consumer.get_session_length() + assert(mocker.called) + + +class TelemetryStorageConsumerAsyncTests(object): + """TelemetryStorageConsumer async test.""" + + @pytest.mark.asyncio + async def test_instances(self): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_consumer = TelemetryStorageConsumerAsync(telemetry_storage) + + assert(isinstance(telemetry_consumer._telemetry_evaluation_consumer, TelemetryEvaluationConsumerAsync)) + assert(isinstance(telemetry_consumer._telemetry_init_consumer, TelemetryInitConsumerAsync)) + assert(isinstance(telemetry_consumer._telemetry_runtime_consumer, TelemetryRuntimeConsumerAsync)) + + assert(telemetry_consumer._telemetry_evaluation_consumer == telemetry_consumer.get_telemetry_evaluation_consumer()) + assert(telemetry_consumer._telemetry_init_consumer == telemetry_consumer.get_telemetry_init_consumer()) + assert(telemetry_consumer._telemetry_runtime_consumer == telemetry_consumer.get_telemetry_runtime_consumer()) + + @pytest.mark.asyncio + async def test_get_bur_time_outs(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def get_bur_time_outs(*args): + self.called = True + telemetry_storage.get_bur_time_outs = get_bur_time_outs + + telemetry_init_consumer = TelemetryInitConsumerAsync(telemetry_storage) + await telemetry_init_consumer.get_bur_time_outs() + assert(self.called) + + @pytest.mark.asyncio + async def get_not_ready_usage(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def get_not_ready_usage(*args): + self.called = True + telemetry_storage.get_not_ready_usage = get_not_ready_usage + + telemetry_init_consumer = TelemetryInitConsumerAsync(telemetry_storage) + await telemetry_init_consumer.get_not_ready_usage() + assert(self.called) + + @pytest.mark.asyncio + async def get_not_ready_usage(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def get_config_stats(*args): + self.called = True + telemetry_storage.get_config_stats = get_config_stats + + telemetry_init_consumer = TelemetryInitConsumerAsync(telemetry_storage) + await telemetry_init_consumer.get_config_stats() + assert(mocker.called) + + @pytest.mark.asyncio + async def pop_exceptions(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_exceptions(*args): + self.called = True + telemetry_storage.pop_exceptions = pop_exceptions + + telemetry_evaluation_consumer = TelemetryEvaluationConsumerAsync(telemetry_storage) + await telemetry_evaluation_consumer.pop_exceptions() + assert(mocker.called) + + @pytest.mark.asyncio + async def pop_latencies(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_latencies(*args): + self.called = True + telemetry_storage.pop_latencies = pop_latencies + + telemetry_evaluation_consumer = TelemetryEvaluationConsumerAsync(telemetry_storage) + await telemetry_evaluation_consumer.pop_latencies() + assert(mocker.called) + + @pytest.mark.asyncio + async def test_get_impressions_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + + async def get_impressions_stats(*args, **kwargs): + self.passed_type = args[0] + + telemetry_storage.get_impressions_stats.side_effect = get_impressions_stats + await telemetry_runtime_consumer.get_impressions_stats('iQ') + assert(self.passed_type == 'iQ') + + @pytest.mark.asyncio + async def test_get_events_stats(self, mocker): + telemetry_storage = mocker.Mock() + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + + async def get_events_stats(*args, **kwargs): + self.event_type = args[0] + + telemetry_storage.get_events_stats.side_effect = get_events_stats + await telemetry_runtime_consumer.get_events_stats('eQ') + assert(self.event_type == 'eQ') + + @pytest.mark.asyncio + async def test_get_last_synchronization(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def get_last_synchronization(*args, **kwargs): + self.called = True + return {'lastSynchronizations': ""} + telemetry_storage.get_last_synchronization = get_last_synchronization + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.get_last_synchronization() + assert(self.called) + + @pytest.mark.asyncio + async def test_pop_tags(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_tags(*args, **kwargs): + self.called = True + telemetry_storage.pop_tags = pop_tags + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_tags() + assert(self.called) + + @pytest.mark.asyncio + async def test_pop_http_errors(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_http_errors(*args, **kwargs): + self.called = True + telemetry_storage.pop_http_errors = pop_http_errors + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_http_errors() + assert(self.called) + + @pytest.mark.asyncio + async def test_pop_http_latencies(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_http_latencies(*args, **kwargs): + self.called = True + telemetry_storage.pop_http_latencies = pop_http_latencies + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_http_latencies() + assert(self.called) + + @pytest.mark.asyncio + async def test_pop_auth_rejections(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_auth_rejections(*args, **kwargs): + self.called = True + telemetry_storage.pop_auth_rejections = pop_auth_rejections + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_auth_rejections() + assert(self.called) + + @pytest.mark.asyncio + async def pop_update_from_sse(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_update_from_sse(*args, **kwargs): + self.called = True + telemetry_storage.pop_update_from_sse = pop_update_from_sse + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_update_from_sse('sp') + assert(self.called) + + @pytest.mark.asyncio + async def test_pop_token_refreshes(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_token_refreshes(*args, **kwargs): + self.called = True + telemetry_storage.pop_token_refreshes = pop_token_refreshes + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_token_refreshes() + assert(self.called) + + @pytest.mark.asyncio + async def test_pop_streaming_events(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def pop_streaming_events(*args, **kwargs): + self.called = True + telemetry_storage.pop_streaming_events = pop_streaming_events + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.pop_streaming_events() + assert(self.called) + + @pytest.mark.asyncio + async def test_get_session_length(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + self.called = False + async def get_session_length(*args, **kwargs): + self.called = True + telemetry_storage.get_session_length = get_session_length + + telemetry_runtime_consumer = TelemetryRuntimeConsumerAsync(telemetry_storage) + await telemetry_runtime_consumer.get_session_length() + assert(self.called) diff --git a/tests/engine/test_unique_keys_tracker.py b/tests/engine/test_unique_keys_tracker.py new file mode 100644 index 00000000..93272f33 --- /dev/null +++ b/tests/engine/test_unique_keys_tracker.py @@ -0,0 +1,124 @@ +"""BloomFilter unit tests.""" +import pytest + +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.engine.filters import BloomFilter + +class UniqueKeysTrackerTests(object): + """StandardRecorderTests test cases.""" + + def test_adding_and_removing_keys(self, mocker): + tracker = UniqueKeysTracker() + + assert(tracker._cache_size > 0) + assert(tracker._current_cache_size == 0) + assert(tracker._cache == {}) + assert(isinstance(tracker._filter, BloomFilter)) + + key1 = 'key1' + key2 = 'key2' + key3 = 'key3' + split1= 'feature1' + split2= 'feature2' + + assert(tracker.track(key1, split1)) + assert(tracker.track(key3, split1)) + assert(not tracker.track(key1, split1)) + assert(tracker.track(key2, split2)) + + assert(tracker._filter.contains(split1+key1)) + assert(not tracker._filter.contains(split1+key2)) + assert(tracker._filter.contains(split2+key2)) + assert(not tracker._filter.contains(split2+key1)) + assert(key1 in tracker._cache[split1]) + assert(key3 in tracker._cache[split1]) + assert(key2 in tracker._cache[split2]) + assert(not key3 in tracker._cache[split2]) + + tracker.clear_filter() + assert(not tracker._filter.contains(split1+key1)) + assert(not tracker._filter.contains(split2+key2)) + + cache_backup = tracker._cache.copy() + cache_size_backup = tracker._current_cache_size + cache, cache_size = tracker.get_cache_info_and_pop_all() + assert(cache_backup == cache) + assert(cache_size_backup == cache_size) + assert(tracker._current_cache_size == 0) + assert(tracker._cache == {}) + + def test_cache_size(self, mocker): + cache_size = 10 + tracker = UniqueKeysTracker(cache_size) + + split1= 'feature1' + for x in range(1, cache_size + 1): + tracker.track('key' + str(x), split1) + split2= 'feature2' + for x in range(1, int(cache_size / 2) + 1): + tracker.track('key' + str(x), split2) + + assert(tracker._current_cache_size == (cache_size + (cache_size / 2))) + assert(len(tracker._cache[split1]) == cache_size) + assert(len(tracker._cache[split2]) == cache_size / 2) + + +class UniqueKeysTrackerAsyncTests(object): + """StandardRecorderTests test cases.""" + + @pytest.mark.asyncio + async def test_adding_and_removing_keys(self, mocker): + tracker = UniqueKeysTrackerAsync() + + assert(tracker._cache_size > 0) + assert(tracker._current_cache_size == 0) + assert(tracker._cache == {}) + assert(isinstance(tracker._filter, BloomFilter)) + + key1 = 'key1' + key2 = 'key2' + key3 = 'key3' + split1= 'feature1' + split2= 'feature2' + + assert(await tracker.track(key1, split1)) + assert(await tracker.track(key3, split1)) + assert(not await tracker.track(key1, split1)) + assert(await tracker.track(key2, split2)) + + assert(tracker._filter.contains(split1+key1)) + assert(not tracker._filter.contains(split1+key2)) + assert(tracker._filter.contains(split2+key2)) + assert(not tracker._filter.contains(split2+key1)) + assert(key1 in tracker._cache[split1]) + assert(key3 in tracker._cache[split1]) + assert(key2 in tracker._cache[split2]) + assert(not key3 in tracker._cache[split2]) + + await tracker.clear_filter() + assert(not tracker._filter.contains(split1+key1)) + assert(not tracker._filter.contains(split2+key2)) + + cache_backup = tracker._cache.copy() + cache_size_backup = tracker._current_cache_size + cache, cache_size = await tracker.get_cache_info_and_pop_all() + assert(cache_backup == cache) + assert(cache_size_backup == cache_size) + assert(tracker._current_cache_size == 0) + assert(tracker._cache == {}) + + @pytest.mark.asyncio + async def test_cache_size(self, mocker): + cache_size = 10 + tracker = UniqueKeysTrackerAsync(cache_size) + + split1= 'feature1' + for x in range(1, cache_size + 1): + await tracker.track('key' + str(x), split1) + split2= 'feature2' + for x in range(1, int(cache_size / 2) + 1): + await tracker.track('key' + str(x), split2) + + assert(tracker._current_cache_size == (cache_size + (cache_size / 2))) + assert(len(tracker._cache[split1]) == cache_size) + assert(len(tracker._cache[split2]) == cache_size / 2) diff --git a/tests/events/test_events_delivery.py b/tests/events/test_events_delivery.py new file mode 100644 index 00000000..27076de4 --- /dev/null +++ b/tests/events/test_events_delivery.py @@ -0,0 +1,44 @@ +"""EventsManager test module.""" +import pytest + +from splitio.models.events import SdkEvent, SdkInternalEvent +from splitio.events.events_metadata import EventsMetadata +from splitio.events.events_delivery import EventsDelivery +from splitio.events.events_metadata import SdkEventType + +class EventsDeliveryTests(object): + """Tests for EventsManager.""" + + sdk_ready_flag = False + metadata = None + + def test_firing_events(self): + events_delivery = EventsDelivery() + + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + events_delivery.deliver(SdkEvent.SDK_READY, metadata, self._sdk_ready_callback) + assert self.sdk_ready_flag + self._verify_metadata(metadata) + + @pytest.mark.asyncio + async def test_firing_events(self): + events_delivery = EventsDelivery() + + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + self.sdk_ready_flag = False + self.metadata = None + await events_delivery.deliver_async(SdkEvent.SDK_READY, metadata, self._sdk_ready_callback_async) + assert self.sdk_ready_flag + self._verify_metadata(metadata) + + def _sdk_ready_callback(self, metadata): + self.sdk_ready_flag = True + self.metadata = metadata + + async def _sdk_ready_callback_async(self, metadata): + self.sdk_ready_flag = True + self.metadata = metadata + + def _verify_metadata(self, metadata): + assert metadata.get_type() == self.metadata.get_type() + assert metadata.get_names() == self.metadata.get_names() \ No newline at end of file diff --git a/tests/events/test_events_manager.py b/tests/events/test_events_manager.py new file mode 100644 index 00000000..6222b68b --- /dev/null +++ b/tests/events/test_events_manager.py @@ -0,0 +1,150 @@ +"""EventsManager test module.""" +import pytest +import asyncio + +from splitio.models.events import SdkEvent, SdkInternalEvent +from splitio.events.events_metadata import EventsMetadata +from splitio.events.events_manager_config import EventsManagerConfig +from splitio.events.events_delivery import EventsDelivery +from splitio.events.events_manager import EventsManager, EventsManagerAsync +from splitio.events.events_metadata import SdkEventType + +class EventsManagerTests(object): + """Tests for EventsManager.""" + + sdk_ready_flag = False + sdk_update_flag = False + metadata = None + + def test_firing_events(self): + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_manager.register(SdkEvent.SDK_READY, self._sdk_ready_callback) + events_manager.register(SdkEvent.SDK_UPDATE, self._sdk_update_callback) + + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + events_manager.notify_internal_event(SdkInternalEvent.FLAGS_UPDATED, metadata) + events_manager.notify_internal_event(SdkInternalEvent.FLAG_KILLED_NOTIFICATION, metadata) + events_manager.notify_internal_event(SdkInternalEvent.RB_SEGMENTS_UPDATED, metadata) + events_manager.notify_internal_event(SdkInternalEvent.SEGMENTS_UPDATED, metadata) + assert not self.sdk_ready_flag + assert not self.sdk_update_flag + + self._reset_flags() + events_manager.notify_internal_event(SdkInternalEvent.SDK_READY, metadata) + assert self.sdk_ready_flag + assert not self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + events_manager.notify_internal_event(SdkInternalEvent.RB_SEGMENTS_UPDATED, metadata) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + events_manager.notify_internal_event(SdkInternalEvent.FLAG_KILLED_NOTIFICATION, metadata) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + events_manager.notify_internal_event(SdkInternalEvent.FLAGS_UPDATED, metadata) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + events_manager.notify_internal_event(SdkInternalEvent.SEGMENTS_UPDATED, metadata) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + def _reset_flags(self): + self.sdk_ready_flag = False + self.sdk_update_flag = False + self.metadata = None + + def _sdk_ready_callback(self, metadata): + self.sdk_ready_flag = True + self.metadata = metadata + + def _sdk_update_callback(self, metadata): + self.sdk_update_flag = True + self.metadata = metadata + + def _verify_metadata(self, metadata): + assert metadata.get_type() == self.metadata.get_type() + assert metadata.get_names() == self.metadata.get_names() + +class EventsManagerAsyncTests(object): + """Tests for EventsManagerAsync.""" + + sdk_ready_flag = False + sdk_update_flag = False + metadata = None + + @pytest.mark.asyncio + async def test_firing_events(self): + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + await events_manager.register(SdkEvent.SDK_READY, self._sdk_ready_callback) + await events_manager.register(SdkEvent.SDK_UPDATE, self._sdk_update_callback) + + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + await events_manager.notify_internal_event(SdkInternalEvent.FLAGS_UPDATED, metadata) + await events_manager.notify_internal_event(SdkInternalEvent.FLAG_KILLED_NOTIFICATION, metadata) + await events_manager.notify_internal_event(SdkInternalEvent.RB_SEGMENTS_UPDATED, metadata) + await events_manager.notify_internal_event(SdkInternalEvent.SEGMENTS_UPDATED, metadata) + assert not self.sdk_ready_flag + assert not self.sdk_update_flag + + self._reset_flags() + await events_manager.notify_internal_event(SdkInternalEvent.SDK_READY, metadata) + await asyncio.sleep(.3) + assert self.sdk_ready_flag + assert not self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + await events_manager.notify_internal_event(SdkInternalEvent.RB_SEGMENTS_UPDATED, metadata) + await asyncio.sleep(.3) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + await events_manager.notify_internal_event(SdkInternalEvent.FLAG_KILLED_NOTIFICATION, metadata) + await asyncio.sleep(.3) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + await events_manager.notify_internal_event(SdkInternalEvent.FLAGS_UPDATED, metadata) + await asyncio.sleep(.3) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + self._reset_flags() + await events_manager.notify_internal_event(SdkInternalEvent.SEGMENTS_UPDATED, metadata) + await asyncio.sleep(.3) + assert not self.sdk_ready_flag + assert self.sdk_update_flag + self._verify_metadata(metadata) + + def _reset_flags(self): + self.sdk_ready_flag = False + self.sdk_update_flag = False + self.metadata = None + + async def _sdk_ready_callback(self, metadata): + self.sdk_ready_flag = True + self.metadata = metadata + + async def _sdk_update_callback(self, metadata): + self.sdk_update_flag = True + self.metadata = metadata + + def _verify_metadata(self, metadata): + assert metadata.get_type() == self.metadata.get_type() + assert metadata.get_names() == self.metadata.get_names() \ No newline at end of file diff --git a/tests/events/test_events_manager_config.py b/tests/events/test_events_manager_config.py new file mode 100644 index 00000000..aa70c4d8 --- /dev/null +++ b/tests/events/test_events_manager_config.py @@ -0,0 +1,34 @@ +"""EventsManagerConfig test module.""" +import pytest + +from splitio.events.events_manager_config import EventsManagerConfig +from splitio.models.events import SdkEvent, SdkInternalEvent + +class EventsManagerConfigTests(object): + """Tests for EventsManagerConfig.""" + + def test_build_instance(self): + config = EventsManagerConfig() + + assert len(config.require_all[SdkEvent.SDK_READY]) == 1 + assert SdkInternalEvent.SDK_READY in config.require_all[SdkEvent.SDK_READY] + + assert SdkEvent.SDK_READY in config.prerequisites[SdkEvent.SDK_UPDATE] + + assert config.execution_limits[SdkEvent.SDK_UPDATE] == -1 + assert config.execution_limits[SdkEvent.SDK_READY] == 1 + + assert len(config.require_any[SdkEvent.SDK_UPDATE]) == 4 + assert SdkInternalEvent.FLAG_KILLED_NOTIFICATION in config.require_any[SdkEvent.SDK_UPDATE] + assert SdkInternalEvent.FLAGS_UPDATED in config.require_any[SdkEvent.SDK_UPDATE] + assert SdkInternalEvent.RB_SEGMENTS_UPDATED in config.require_any[SdkEvent.SDK_UPDATE] + assert SdkInternalEvent.SEGMENTS_UPDATED in config.require_any[SdkEvent.SDK_UPDATE] + + order = 0 + assert len(config.evaluation_order) == 2 + for sdk_event in config.evaluation_order: + order += 1 + if order == 1: + assert sdk_event == SdkEvent.SDK_READY + if order == 2: + assert sdk_event == SdkEvent.SDK_UPDATE \ No newline at end of file diff --git a/tests/events/test_events_metadata.py b/tests/events/test_events_metadata.py new file mode 100644 index 00000000..3ce90d0f --- /dev/null +++ b/tests/events/test_events_metadata.py @@ -0,0 +1,21 @@ +"""EventsMetadata test module.""" +import pytest + +from splitio.events.events_metadata import EventsMetadata +from splitio.events.events_metadata import SdkEventType + +class EventsMetadataTests(object): + """Tests for EventsMetadata.""" + + def test_build_instance(self): + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + assert len(metadata.get_names()) == 1 + assert metadata.get_names().pop() == "feature1" + assert len(metadata.get_names()) == 0 + assert metadata.get_type() == SdkEventType.FLAG_UPDATE + + def test_sanitize_none_input(self): + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1", None, 123, False }) + assert len(metadata.get_names()) == 1 + assert metadata.get_names().pop() == "feature1" + assert len(metadata.get_names()) == 0 diff --git a/tests/events/test_events_task.py b/tests/events/test_events_task.py new file mode 100644 index 00000000..d667f76c --- /dev/null +++ b/tests/events/test_events_task.py @@ -0,0 +1,139 @@ +"""EventsManager test module.""" +import pytest +import queue +import time +import asyncio + +from splitio.models.events import SdkInternalEvent +from splitio.models.notification import SdkInternalEventNotification +from splitio.events.events_metadata import EventsMetadata +from splitio.events.events_metadata import SdkEventType +from splitio.events.events_task import EventsTask, EventsTaskAsync + + +class EventsTaskTests(object): + """Tests for EventsTask.""" + + internal_event = None + metadata = None + + def test_firing_events(self): + events_queue = queue.Queue() + events_task = EventsTask(self._event_callback, events_queue) + + events_task.start() + assert events_task.is_running() + + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, metadata)) + time.sleep(.5) + assert self.internal_event == SdkInternalEvent.SDK_READY + self._verify_metadata(metadata) + + self._reset_flags() + events_queue.put(SdkInternalEventNotification(SdkInternalEvent.RB_SEGMENTS_UPDATED, metadata)) + time.sleep(.5) + assert self.internal_event == SdkInternalEvent.RB_SEGMENTS_UPDATED + self._verify_metadata(metadata) + + events_task.stop() + time.sleep(.5) + assert not events_task.is_running() + + def test_on_error(self): + events_queue = queue.Queue() + + def handler_sync(internal_event, metadata): + raise Exception('some') + + events_task = EventsTask(handler_sync, events_queue) + events_task.start() + assert events_task.is_running() + + events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, None)) + + with pytest.raises(Exception): + events_task._handler() + + assert events_task.is_running() + events_task.stop() + time.sleep(1) + assert not events_task.is_running() + + def _reset_flags(self): + self.internal_event = None + self.metadata = None + + def _event_callback(self, internal_event, metadata): + self.internal_event = internal_event + self.metadata = metadata + + def _verify_metadata(self, metadata): + assert metadata.get_type() == self.metadata.get_type() + assert metadata.get_names() == self.metadata.get_names() + + +class EventsTaskAsyncTests(object): + """Tests for EventsTaskAsyncr.""" + + internal_event = None + metadata = None + + @pytest.mark.asyncio + async def test_firing_events(self): + events_queue = asyncio.Queue() + events_task = EventsTaskAsync(self._event_callback, events_queue) + + events_task.start() + assert events_task.is_running() + + metadata = EventsMetadata(SdkEventType.FLAG_UPDATE, { "feature1" }) + await events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, metadata)) + await asyncio.sleep(.5) + assert self.internal_event == SdkInternalEvent.SDK_READY + self._verify_metadata(metadata) + + self._reset_flags() + await events_queue.put(SdkInternalEventNotification(SdkInternalEvent.RB_SEGMENTS_UPDATED, metadata)) + await asyncio.sleep(.5) + assert self.internal_event == SdkInternalEvent.RB_SEGMENTS_UPDATED + self._verify_metadata(metadata) + + await events_task.stop() + await asyncio.sleep(.5) + assert not events_task.is_running() + + @pytest.mark.asyncio + async def test_on_error(self): + events_queue = asyncio.Queue() + + async def handler_sync(internal_event, metadata): + raise Exception('some') + + events_task = EventsTaskAsync(handler_sync, events_queue) + events_task.start() + assert events_task.is_running() + + await events_queue.put(SdkInternalEventNotification(SdkInternalEvent.SDK_READY, None)) + + with pytest.raises(Exception): + events_task._handler() + + assert events_task.is_running() + await events_task.stop() + await asyncio.sleep(1) + assert not events_task.is_running() + + def _reset_flags(self): + self.internal_event = None + self.metadata = None + + async def _event_callback(self, internal_event, metadata): + self.internal_event = internal_event + self.metadata = metadata + + def _verify_metadata(self, metadata): + assert metadata.get_type() == self.metadata.get_type() + assert metadata.get_names() == self.metadata.get_names() + + \ No newline at end of file diff --git a/tests/helpers/mockserver.py b/tests/helpers/mockserver.py index d85bcfea..8d41cfd2 100644 --- a/tests/helpers/mockserver.py +++ b/tests/helpers/mockserver.py @@ -3,12 +3,13 @@ from collections import namedtuple import queue import threading +import pytest from http.server import HTTPServer, BaseHTTPRequestHandler Request = namedtuple('Request', ['method', 'path', 'headers', 'body']) - +OLD_SPEC = False class SSEMockServer(object): """SSE server for testing purposes.""" @@ -23,8 +24,7 @@ def __init__(self, req_queue=None): self._queue = queue.Queue() self._server = HTTPServer(('localhost', 0), lambda *xs: SSEHandler(self._queue, *xs, req_queue=req_queue)) - self._server_thread = threading.Thread(target=self._blocking_run) - self._server_thread.setDaemon(True) + self._server_thread = threading.Thread(target=self._blocking_run, daemon=True) self._done_event = threading.Event() def _blocking_run(self): @@ -103,21 +103,23 @@ class SplitMockServer(object): protocol_version = 'HTTP/1.1' def __init__(self, split_changes=None, segment_changes=None, req_queue=None, - auth_response=None): + auth_response=None, old_spec=False): """ Consruct a mock server. :param changes: mapping of changeNumbers to splitChanges responses :type changes: dict """ + global OLD_SPEC + OLD_SPEC = old_spec split_changes = split_changes if split_changes is not None else {} segment_changes = segment_changes if segment_changes is not None else {} self._server = HTTPServer(('localhost', 0), lambda *xs: SDKHandler(split_changes, segment_changes, *xs, req_queue=req_queue, - auth_response=auth_response)) - self._server_thread = threading.Thread(target=self._blocking_run, name="SplitMockServer") - self._server_thread.setDaemon(True) + auth_response=auth_response, + )) + self._server_thread = threading.Thread(target=self._blocking_run, name="SplitMockServer", daemon=True) self._done_event = threading.Event() def _blocking_run(self): @@ -150,7 +152,7 @@ def __init__(self, split_changes, segment_changes, *args, **kwargs): self._req_queue = kwargs.get('req_queue') self._auth_response = kwargs.get('auth_response') self._split_changes = split_changes - self._segment_changes = segment_changes + self._segment_changes = segment_changes BaseHTTPRequestHandler.__init__(self, *args) def _parse_qs(self): @@ -182,6 +184,15 @@ def _handle_segment_changes(self): self.wfile.write(json.dumps(to_send).encode('utf-8')) def _handle_split_changes(self): + global OLD_SPEC + if OLD_SPEC: + self.send_response(400) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write('{}'.encode('utf-8')) + OLD_SPEC = False + return + qstring = self._parse_qs() since = int(qstring.get('since', -1)) to_send = self._split_changes.get(since) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e69de29b..845e8c72 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -0,0 +1,55 @@ +import copy + +rbsegments_json = [{"changeNumber": 12, "name": "some_segment", "status": "ACTIVE","trafficTypeName": "user","excluded":{"keys":[],"segments":[]},"conditions": []}] + +split11 = {"ff": {"t": 1675443569027, "s": -1, "d": [ + {"trafficTypeName": "user", "name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779, "seed": -113875324, "status": "ACTIVE","killed": False, "defaultTreatment": "off", "changeNumber": 1675443569027,"algo": 2, "configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}], "sets": ["set_1"], "impressionsDisabled": False, 'prerequisites': []}, + {"trafficTypeName": "user", "name": "SPLIT_1", "trafficAllocation": 100, "trafficAllocationSeed": -1780071202,"seed": -1442762199, "status": "ACTIVE","killed": False, "defaultTreatment": "off", "changeNumber": 1675443537882,"algo": 2, "configurations": {},"conditions": [{"conditionType": "ROLLOUT", "matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 0 },{ "treatment": "off", "size": 100 }],"label": "default rule"}], "sets": ["set_1", "set_2"]}, + {"trafficTypeName": "user", "name": "SPLIT_3","trafficAllocation": 100,"trafficAllocationSeed": 1057590779, "seed": -113875324, "status": "ACTIVE","killed": False, "defaultTreatment": "off", "changeNumber": 1675443569027,"algo": 2, "configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}], "sets": ["set_1"], "impressionsDisabled": True} + ]}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} +split12 = {"ff": {"s": 1675443569027,"t": 1675443767284, "d": [{"trafficTypeName": "user","name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779,"seed": -113875324,"status": "ACTIVE","killed": True,"defaultTreatment": "off","changeNumber": 1675443767288,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}]}]}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} +split13 = {"ff": {"s": 1675443767288,"t": 1675443984594, "d": [ + {"trafficTypeName": "user","name": "SPLIT_1","trafficAllocation": 100,"trafficAllocationSeed": -1780071202,"seed": -1442762199,"status": "ARCHIVED","killed": False,"defaultTreatment": "off","changeNumber": 1675443984594,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 0 },{ "treatment": "off", "size": 100 }],"label": "default rule"}]}, + {"trafficTypeName": "user","name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779,"seed": -113875324,"status": "ACTIVE","killed": False,"defaultTreatment": "off","changeNumber": 1675443954220,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}]} + ]}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} + + +split41 = {"ff": {"t": None, "s": None, "d": split11['ff']['d']}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} +split42 = {"ff": {"t": None, "s": None, "d": split12['ff']['d']}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} +split43 = {"ff": {"t": None, "s": None, "d": split13['ff']['d']}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} + +split61 = {"ff": {"t": -1, "s": -1, "d": split11['ff']['d']}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} +split62 = {"ff": {"t": -1, "s": -1, "d": split12['ff']['d']}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} +split63 = {"ff": {"t": -1, "s": -1, "d": split13['ff']['d']}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}} + +splits_json = { + "splitChange1_1": split11, + "splitChange1_2": split12, + "splitChange1_3": split13, + "splitChange2_1": {"ff": {"t": -1, "s": -1, "d": [{"name": "SPLIT_1","status": "ACTIVE","killed": False,"defaultTreatment": "off","configurations": {},"conditions": []}]}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}}, + "splitChange3_1": {"ff": {"t": -1, "s": -1, "d": [{"trafficTypeName": "user","name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779,"seed": -113875324,"status": "ACTIVE","killed": False,"defaultTreatment": "off","changeNumber": 1675443569027,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}]}],"s": -1,"t": 1675443569027}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}}, + "splitChange3_2": {"ff": {"t": -1, "s": -1, "d": [{"trafficTypeName": "user","name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779,"seed": -113875324,"status": "ACTIVE","killed": True,"defaultTreatment": "off","changeNumber": 1675443767288,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}]}],"s": 1675443569027,"t": 1675443569027}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}}, + "splitChange4_1": split41, + "splitChange4_2": split42, + "splitChange4_3": split43, + "splitChange5_1": {"ff": {"t": -1, "s": -1, "d": [{"trafficTypeName": "user","name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779,"seed": -113875324,"status": "ACTIVE","killed": False,"defaultTreatment": "off","changeNumber": 1675443569027,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}]}],"s": -1,"t": 1675443569027}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}}, + "splitChange5_2": {"ff": {"t": -1, "s": -1, "d": [{"trafficTypeName": "user","name": "SPLIT_2","trafficAllocation": 100,"trafficAllocationSeed": 1057590779,"seed": -113875324,"status": "ACTIVE","killed": True,"defaultTreatment": "off","changeNumber": 1675443767288,"algo": 2,"configurations": {},"conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": { "trafficType": "user", "attribute": None },"matcherType": "ALL_KEYS","negate": False,"userDefinedSegmentMatcherData": None,"whitelistMatcherData": None,"unaryNumericMatcherData": None,"betweenMatcherData": None,"booleanMatcherData": None,"dependencyMatcherData": None,"stringMatcherData": None}]},"partitions": [{ "treatment": "on", "size": 100 },{ "treatment": "off", "size": 0 }],"label": "default rule"}]}],"s": 1675443569026,"t": 1675443569026}, "rbs": {"t": -1, "s": -1, "d": rbsegments_json}}, + "splitChange6_1": split61, + "splitChange6_2": split62, + "splitChange6_3": split63, + "splitChange7_1": {"ff": { + "t": -1, + "s": -1, + "d": [{"changeNumber": 10,"trafficTypeName": "user","name": "rbs_feature_flag","trafficAllocation": 100,"trafficAllocationSeed": 1828377380,"seed": -286617921,"status": "ACTIVE","killed": False,"defaultTreatment": "off","algo": 2, + "conditions": [{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": {"trafficType": "user"},"matcherType": "IN_RULE_BASED_SEGMENT","negate": False,"userDefinedSegmentMatcherData": {"segmentName": "sample_rule_based_segment"}}]},"partitions": [{"treatment": "on","size": 100},{"treatment": "off","size": 0}],"label": "in rule based segment sample_rule_based_segment"},{"conditionType": "ROLLOUT","matcherGroup": {"combiner": "AND","matchers": [{"keySelector": {"trafficType": "user"},"matcherType": "ALL_KEYS","negate": False}]},"partitions": [{"treatment": "on","size": 0},{"treatment": "off","size": 100}],"label": "default rule"}], + "configurations": {}, + "sets": [], + "impressionsDisabled": False + }] + }, "rbs": { + "t": 1675259356568, + "s": -1, + "d": [{"changeNumber": 5,"name": "sample_rule_based_segment","status": "ACTIVE","trafficTypeName": "user","excluded":{"keys":["mauro@split.io","gaston@split.io"],"segments":[]}, + "conditions": [{"matcherGroup": {"combiner": "AND","matchers": [{"keySelector": {"trafficType": "user","attribute": "email"},"matcherType": "ENDS_WITH","negate": False,"whitelistMatcherData": {"whitelist": ["@split.io"]}}]}}]} + ]}} +} \ No newline at end of file diff --git a/tests/integration/files/splitChanges.json b/tests/integration/files/splitChanges.json index d5401c93..84f7c2cd 100644 --- a/tests/integration/files/splitChanges.json +++ b/tests/integration/files/splitChanges.json @@ -1,5 +1,6 @@ { - "splits": [ + "ff": { + "d": [ { "orgId": null, "environment": null, @@ -22,7 +23,8 @@ "userDefinedSegmentMatcherData": null, "whitelistMatcherData": { "whitelist": [ - "whitelisted_user" + "whitelisted_user", + "user1234" ] } } @@ -58,7 +60,8 @@ } ] } - ] + ], + "sets": ["set1", "set2"] }, { "orgId": null, @@ -95,7 +98,8 @@ } ] } - ] + ], + "sets": ["set4"] }, { "orgId": null, @@ -136,7 +140,8 @@ } ] } - ] + ], + "sets": ["set3"] }, { "orgId": null, @@ -199,7 +204,8 @@ } ] } - ] + ], + "sets": ["set1"] }, { "orgId": null, @@ -239,7 +245,8 @@ } ] } - ] + ], + "sets": [] }, { "orgId": null, @@ -276,7 +283,8 @@ } ] } - ] + ], + "sets": [] }, { "orgId": null, @@ -313,9 +321,156 @@ } ] } + ], + "sets": [] + }, + { + "changeNumber": 10, + "trafficTypeName": "user", + "name": "rbs_feature_flag", + "trafficAllocation": 100, + "trafficAllocationSeed": 1828377380, + "seed": -286617921, + "status": "ACTIVE", + "killed": false, + "defaultTreatment": "off", + "algo": 2, + "conditions": [ + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "sample_rule_based_segment" + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ], + "label": "in rule based segment sample_rule_based_segment" + }, + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "ALL_KEYS", + "negate": false + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ], + "label": "default rule" + } + ], + "configurations": {}, + "sets": [], + "impressionsDisabled": false + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "prereq_feature", + "seed": 1699838640, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off_default", + "conditions": [ + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ], + "sets": [], + "prerequisites": [ + {"n": "regex_test", "ts": ["on"]}, + {"n": "whitelist_feature", "ts": ["off"]} ] } ], - "since": -1, - "till": 1457726098069 -} + "s": -1, + "t": 1457726098069 +}, "rbs": {"t": -1, "s": -1, "d": [{ + "changeNumber": 123, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": false, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] +}]}} diff --git a/tests/integration/files/split_changes.json b/tests/integration/files/split_changes.json index f536346d..f0708043 100644 --- a/tests/integration/files/split_changes.json +++ b/tests/integration/files/split_changes.json @@ -1,5 +1,6 @@ { - "splits": [ + "ff": { + "d": [ { "orgId": null, "environment": null, @@ -58,7 +59,8 @@ } ] } - ] + ], + "sets": ["set1", "set2"] }, { "orgId": null, @@ -95,7 +97,8 @@ } ] } - ] + ], + "sets": ["set4"] }, { "orgId": null, @@ -136,7 +139,8 @@ } ] } - ] + ], + "sets": ["set3"] }, { "orgId": null, @@ -199,7 +203,8 @@ } ] } - ] + ], + "sets": ["set1"] }, { "orgId": null, @@ -239,7 +244,8 @@ } ] } - ] + ], + "sets": [] }, { "orgId": null, @@ -276,7 +282,8 @@ } ] } - ] + ], + "sets": [] }, { "orgId": null, @@ -313,9 +320,11 @@ } ] } - ] + ], + "sets": [] } ], - "since": -1, - "till": 1457726098069 + "s": -1, + "t": 1457726098069 +}, "rbs": {"t": -1, "s": -1, "d": []} } diff --git a/tests/integration/files/split_changes_temp.json b/tests/integration/files/split_changes_temp.json new file mode 100644 index 00000000..24d876a4 --- /dev/null +++ b/tests/integration/files/split_changes_temp.json @@ -0,0 +1 @@ +{"ff": {"t": -1, "s": -1, "d": [{"name": "SPLIT_1", "status": "ACTIVE", "killed": false, "defaultTreatment": "off", "configurations": {}, "conditions": []}]}, "rbs": {"t": -1, "s": -1, "d": [{"changeNumber": 12, "name": "some_segment", "status": "ACTIVE", "trafficTypeName": "user", "excluded": {"keys": [], "segments": []}, "conditions": []}]}} \ No newline at end of file diff --git a/tests/integration/files/split_old_spec.json b/tests/integration/files/split_old_spec.json new file mode 100644 index 00000000..0d7edf86 --- /dev/null +++ b/tests/integration/files/split_old_spec.json @@ -0,0 +1,328 @@ +{ + "splits": [ + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "whitelist_feature", + "seed": -1222652054, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "WHITELIST", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": { + "whitelist": [ + "whitelisted_user" + ] + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + } + ] + }, + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ] + } + ], + "sets": ["set1", "set2"] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "all_feature", + "seed": 1699838640, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ], + "sets": ["set4"] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "killed_feature", + "seed": -480091424, + "status": "ACTIVE", + "killed": true, + "changeNumber": 123, + "defaultTreatment": "defTreatment", + "configurations": { + "off": "{\"size\":15,\"test\":20}", + "defTreatment": "{\"size\":15,\"defTreatment\":true}" + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "defTreatment", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ], + "sets": ["set3"] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "sample_feature", + "seed": 1548363147, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "configurations": { + "on": "{\"size\":15,\"test\":20}" + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + } + ] + }, + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "human_beigns" + }, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 30 + }, + { + "treatment": "off", + "size": 70 + } + ] + } + ], + "sets": ["set1"] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "dependency_test", + "seed": 1222652054, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SPLIT_TREATMENT", + "negate": false, + "userDefinedSegmentMatcherData": null, + "dependencyMatcherData": { + "split": "all_feature", + "treatments": ["on"] + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ] + } + ], + "sets": [] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "regex_test", + "seed": 1222652051, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "MATCHES_STRING", + "negate": false, + "userDefinedSegmentMatcherData": null, + "stringMatcherData": "abc[0-9]" + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ], + "sets": [] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "boolean_test", + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "seed": 12321809, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "EQUAL_TO_BOOLEAN", + "negate": false, + "userDefinedSegmentMatcherData": null, + "booleanMatcherData": true + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ], + "sets": [] + } + ], + "since": -1, + "till": 1457726098069 +} \ No newline at end of file diff --git a/tests/integration/test_client_e2e.py b/tests/integration/test_client_e2e.py index 50ea1cae..26efcd42 100644 --- a/tests/integration/test_client_e2e.py +++ b/tests/integration/test_client_e2e.py @@ -1,37 +1,539 @@ """Client integration tests.""" # pylint: disable=protected-access,line-too-long,no-self-use +from asyncio import Queue import json import os import threading - +import time +import pytest +import queue +import unittest.mock as mocker from redis import StrictRedis -from splitio.client.factory import get_factory, SplitFactory +from splitio.optional.loaders import asyncio +from splitio.exceptions import TimeoutException +from splitio.client.factory import get_factory, SplitFactory, get_factory_async, SplitFactoryAsync from splitio.client.util import SdkMetadata +from splitio.client.config import DEFAULT_CONFIG +from splitio.client.client import EvaluationOptions +from splitio.engine.impressions.impressions import Manager as ImpressionsManager, ImpressionsMode +from splitio.engine.impressions import set_classes, set_classes_async +from splitio.engine.impressions.strategies import StrategyDebugMode, StrategyOptimizedMode, StrategyNoneMode +from splitio.engine.telemetry import TelemetryStorageConsumer, TelemetryStorageProducer, TelemetryStorageConsumerAsync,\ + TelemetryStorageProducerAsync +from splitio.engine.impressions.manager import Counter as ImpressionsCounter +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.events.events_delivery import EventsDelivery +from splitio.events.events_manager import EventsManager, EventsManagerAsync +from splitio.events.events_manager_config import EventsManagerConfig +from splitio.events.events_task import EventsTask, EventsTaskAsync +from splitio.models import splits, segments, rule_based_segments +from splitio.models.events import SdkEvent +from splitio.models.fallback_config import FallbackTreatmentsConfiguration, FallbackTreatmentCalculator +from splitio.models.fallback_treatment import FallbackTreatment +from splitio.recorder.recorder import StandardRecorder, PipelinedRecorder, StandardRecorderAsync, PipelinedRecorderAsync from splitio.storage.inmemmory import InMemoryEventStorage, InMemoryImpressionStorage, \ - InMemorySegmentStorage, InMemorySplitStorage + InMemorySegmentStorage, InMemorySplitStorage, InMemoryTelemetryStorage, InMemorySplitStorageAsync,\ + InMemoryEventStorageAsync, InMemoryImpressionStorageAsync, InMemorySegmentStorageAsync, \ + InMemoryTelemetryStorageAsync, InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ - RedisSplitStorage, RedisSegmentStorage -from splitio.storage.adapters.redis import build, RedisAdapter -from splitio.models import splits, segments -from splitio.engine.impressions import Manager as ImpressionsManager, ImpressionsMode -from splitio.recorder.recorder import StandardRecorder, PipelinedRecorder -from splitio.client.config import DEFAULT_CONFIG + RedisSplitStorage, RedisSegmentStorage, RedisTelemetryStorage, RedisEventsStorageAsync,\ + RedisImpressionsStorageAsync, RedisSegmentStorageAsync, RedisSplitStorageAsync, RedisTelemetryStorageAsync, \ + RedisRuleBasedSegmentsStorage, RedisRuleBasedSegmentsStorageAsync +from splitio.storage.pluggable import PluggableEventsStorage, PluggableImpressionsStorage, PluggableSegmentStorage, \ + PluggableTelemetryStorage, PluggableSplitStorage, PluggableEventsStorageAsync, PluggableImpressionsStorageAsync, \ + PluggableSegmentStorageAsync, PluggableSplitStorageAsync, PluggableTelemetryStorageAsync, \ + PluggableRuleBasedSegmentsStorage, PluggableRuleBasedSegmentsStorageAsync +from splitio.storage.adapters.redis import build, RedisAdapter, RedisAdapterAsync, build_async +from splitio.sync.synchronizer import SplitTasks, SplitSynchronizers, Synchronizer, RedisSynchronizer, SynchronizerAsync,\ +RedisSynchronizerAsync +from splitio.sync.manager import Manager, RedisManager, ManagerAsync, RedisManagerAsync +from splitio.sync.synchronizer import PluggableSynchronizer, PluggableSynchronizerAsync +from splitio.sync.telemetry import RedisTelemetrySubmitter, RedisTelemetrySubmitterAsync + +from tests.helpers.mockserver import SplitMockServer +from tests.integration import splits_json +from tests.storage.test_pluggable import StorageMockAdapter, StorageMockAdapterAsync + +def _validate_last_impressions(client, *to_validate): + """Validate the last N impressions are present disregarding the order.""" + imp_storage = client._factory._get_storage('impressions') + as_tup_set = set() + if isinstance(client._factory._get_storage('splits'), RedisSplitStorage) or isinstance(client._factory._get_storage('splits'), PluggableSplitStorage): + if isinstance(client._factory._get_storage('splits'), RedisSplitStorage): + redis_client = imp_storage._redis + impressions_raw = [ + json.loads(redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY)) + for _ in to_validate + ] + else: + pluggable_adapter = imp_storage._pluggable_adapter + results = pluggable_adapter.pop_items(imp_storage._impressions_queue_key) + results = [] if results == None else results + impressions_raw = [ + json.loads(i) + for i in results + ] + if to_validate != (): + if len(to_validate[0]) == 3: + as_tup_set = set( + (i['i']['f'], i['i']['k'], i['i']['t']) + for i in impressions_raw + ) + else: + as_tup_set = set( + (i['i']['f'], i['i']['k'], i['i']['t'], i['i']['properties']) + for i in impressions_raw + ) + + assert as_tup_set == set(to_validate) + time.sleep(0.2) # delay for redis to sync + else: + impressions = imp_storage.pop_many(len(to_validate)) + if to_validate != (): + if len(to_validate[0]) == 3: + as_tup_set = set((i.feature_name, i.matching_key, i.treatment) for i in impressions) + else: + as_tup_set = set((i.feature_name, i.matching_key, i.treatment, i.properties) for i in impressions) + + assert as_tup_set == set(to_validate) + +def _validate_last_events(client, *to_validate): + """Validate the last N impressions are present disregarding the order.""" + event_storage = client._factory._get_storage('events') + if isinstance(client._factory._get_storage('splits'), RedisSplitStorage) or isinstance(client._factory._get_storage('splits'), PluggableSplitStorage): + if isinstance(client._factory._get_storage('splits'), RedisSplitStorage): + redis_client = event_storage._redis + events_raw = [ + json.loads(redis_client.lpop(event_storage._EVENTS_KEY_TEMPLATE)) + for _ in to_validate + ] + else: + pluggable_adapter = event_storage._pluggable_adapter + events_raw = [ + json.loads(i) + for i in pluggable_adapter.pop_items(event_storage._events_queue_key) + ] + as_tup_set = set( + (i['e']['key'], i['e']['trafficTypeName'], i['e']['eventTypeId'], i['e']['value'], str(i['e']['properties'])) + for i in events_raw + ) + assert as_tup_set == set(to_validate) + else: + events = event_storage.pop_many(len(to_validate)) + as_tup_set = set((i.key, i.traffic_type_name, i.event_type_id, i.value, str(i.properties)) for i in events) + assert as_tup_set == set(to_validate) + +def _get_treatment(factory, skip_rbs=False): + """Test client.get_treatment().""" + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'sample_feature', evaluation_options=EvaluationOptions({"prop": "value"})) == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on', '{"prop": "value"}')) + + assert client.get_treatment('invalidKey', 'sample_feature') == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + assert client.get_treatment('invalidKey', 'invalid_feature') == 'control' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) # No impressions should be present + + # testing a killed feature. No matter what the key, must return default treatment + assert client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + assert client.get_treatment('invalidKey', 'all_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # testing WHITELIST matcher + assert client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('whitelist_feature', 'whitelisted_user', 'on')) + assert client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) + + # testing INVALID matcher + assert client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) # No impressions should be present + + # testing Dependency matcher + assert client.get_treatment('somekey', 'dependency_test') == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('dependency_test', 'somekey', 'off')) + + # testing boolean matcher + assert client.get_treatment('True', 'boolean_test') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('boolean_test', 'True', 'on')) + + # testing regex matcher + assert client.get_treatment('abc4', 'regex_test') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('regex_test', 'abc4', 'on')) + + if skip_rbs: + return + + # test rule based segment matcher + assert client.get_treatment('bilal@split.io', 'rbs_feature_flag', {'email': 'bilal@split.io'}) == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('rbs_feature_flag', 'bilal@split.io', 'on')) + + # test rule based segment matcher + assert client.get_treatment('mauro@split.io', 'rbs_feature_flag', {'email': 'mauro@split.io'}) == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('rbs_feature_flag', 'mauro@split.io', 'off')) + + # test prerequisites matcher + assert client.get_treatment('abc4', 'prereq_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('prereq_feature', 'abc4', 'on')) + + # test prerequisites matcher + assert client.get_treatment('user1234', 'prereq_feature') == 'off_default' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('prereq_feature', 'user1234', 'off_default')) + + # test fallback treatment + assert client.get_treatment('user4321', 'fallback_feature') == 'on-local' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) # No impressions should be present + +def _get_treatment_with_config(factory): + """Test client.get_treatment_with_config().""" + try: + client = factory.client() + except: + pass + result = client.get_treatment_with_config('user1', 'sample_feature') + assert result == ('on', '{"size":15,"test":20}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatment_with_config('invalidKey', 'sample_feature') + assert result == ('off', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatment_with_config('invalidKey', 'invalid_feature') + assert result == ('control', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatment_with_config('invalidKey', 'killed_feature') + assert ('defTreatment', '{"size":15,"defTreatment":true}') == result + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatment_with_config('invalidKey', 'all_feature') + assert result == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # test fallback treatment + assert client.get_treatment_with_config('user4321', 'fallback_feature') == ('on-local', '{"prop": "val"}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) # No impressions should be present + +def _get_treatments(factory): + """Test client.get_treatments().""" + try: + client = factory.client() + except: + pass + result = client.get_treatments('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatments('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatments('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == 'control' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # test fallback treatment + assert client.get_treatments('user4321', ['fallback_feature']) == {'fallback_feature': 'on-local'} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) # No impressions should be present + +def _get_treatments_with_config(factory): + """Test client.get_treatments_with_config().""" + try: + client = factory.client() + except: + pass + + result = client.get_treatments_with_config('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('on', '{"size":15,"test":20}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatments_with_config('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('off', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatments_with_config('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == ('control', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments_with_config('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments_with_config('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # test fallback treatment + assert client.get_treatments_with_config('user4321', ['fallback_feature']) == {'fallback_feature': ('on-local', '{"prop": "val"}')} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client) # No impressions should be present + +def _get_treatments_by_flag_set(factory): + """Test client.get_treatments_by_flag_set().""" + try: + client = factory.client() + except: + pass + result = client.get_treatments_by_flag_set('user1', 'set1') + assert len(result) == 2 + assert result == {'sample_feature': 'on', 'whitelist_feature': 'off'} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = client.get_treatments_by_flag_set('invalidKey', 'invalid_set') + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments_by_flag_set('invalidKey', 'set3') + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments_by_flag_set('invalidKey', 'set4') + assert len(result) == 1 + assert result['all_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + +def _get_treatments_by_flag_sets(factory): + """Test client.get_treatments_by_flag_sets().""" + try: + client = factory.client() + except: + pass + result = client.get_treatments_by_flag_sets('user1', ['set1']) + assert len(result) == 2 + assert result == {'sample_feature': 'on', 'whitelist_feature': 'off'} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = client.get_treatments_by_flag_sets('invalidKey', ['invalid_set']) + assert len(result) == 0 + assert result == {} + + result = client.get_treatments_by_flag_sets('invalidKey', []) + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments_by_flag_sets('invalidKey', ['set3']) + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments_by_flag_sets('user1', ['set4']) + assert len(result) == 1 + assert result['all_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'user1', 'on')) + +def _get_treatments_with_config_by_flag_set(factory): + """Test client.get_treatments_with_config_by_flag_set().""" + try: + client = factory.client() + except: + pass + result = client.get_treatments_with_config_by_flag_set('user1', 'set1') + assert len(result) == 2 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), 'whitelist_feature': ('off', None)} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = client.get_treatments_with_config_by_flag_set('invalidKey', 'invalid_set') + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments_with_config_by_flag_set('invalidKey', 'set3') + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments_with_config_by_flag_set('invalidKey', 'set4') + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + +def _get_treatments_with_config_by_flag_sets(factory): + """Test client.get_treatments_with_config_by_flag_sets().""" + try: + client = factory.client() + except: + pass + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1']) + assert len(result) == 2 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), 'whitelist_feature': ('off', None)} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = client.get_treatments_with_config_by_flag_sets('invalidKey', ['invalid_set']) + assert len(result) == 0 + assert result == {} + + result = client.get_treatments_with_config_by_flag_sets('invalidKey', []) + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments_with_config_by_flag_sets('invalidKey', ['set3']) + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments_with_config_by_flag_sets('user1', ['set4']) + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + _validate_last_impressions(client, ('all_feature', 'user1', 'on')) + +def _track(factory): + """Test client.track().""" + try: + client = factory.client() + except: + pass + assert(client.track('user1', 'user', 'conversion', 1, {"prop1": "value1"})) + assert(not client.track(None, 'user', 'conversion')) + assert(not client.track('user1', None, 'conversion')) + assert(not client.track('user1', 'user', None)) + _validate_last_events( + client, + ('user1', 'user', 'conversion', 1, "{'prop1': 'value1'}") + ) + +def _manager_methods(factory, skip_rbs=False): + """Test manager.split/splits.""" + try: + manager = factory.manager() + except: + pass + result = manager.split('all_feature') + assert result.name == 'all_feature' + assert result.traffic_type is None + assert result.killed is False + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs == {} + + result = manager.split('killed_feature') + assert result.name == 'killed_feature' + assert result.traffic_type is None + assert result.killed is True + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}' + assert result.configs['off'] == '{"size":15,"test":20}' + + result = manager.split('sample_feature') + assert result.name == 'sample_feature' + assert result.traffic_type is None + assert result.killed is False + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs['on'] == '{"size":15,"test":20}' + + if skip_rbs: + assert len(manager.split_names()) == 7 + assert len(manager.splits()) == 7 + return + assert len(manager.split_names()) == 9 + assert len(manager.splits()) == 9 -class InMemoryIntegrationTests(object): +class InMemoryDebugIntegrationTests(object): """Inmemory storage-based integration tests.""" def setup_method(self): """Prepare storages with test data.""" - split_storage = InMemorySplitStorage() - segment_storage = InMemorySegmentStorage() + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') with open(split_fn, 'r') as flo: data = json.loads(flo.read()) - for split in data['splits']: - split_storage.put(splits.from_raw(split)) + for split in data['ff']['d']: + split_storage.update([splits.from_raw(split)], [], 0) + + for rbs in data['rbs']['d']: + rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') with open(segment_fn, 'r') as flo: @@ -43,15 +545,39 @@ def setup_method(self): data = json.loads(flo.read()) segment_storage.put(segments.from_raw(data)) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + storages = { 'splits': split_storage, 'segments': segment_storage, - 'impressions': InMemoryImpressionStorage(5000), - 'events': InMemoryEventStorage(5000), + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), } - impmanager = ImpressionsManager(storages['impressions'].put, ImpressionsMode.DEBUG) - recorder = StandardRecorder(impmanager, storages['events'], storages['impressions']) - self.factory = SplitFactory('some_api_key', storages, True, recorder) # pylint:disable=attribute-defined-outside-init + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, events_queue) + + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass def teardown_method(self): """Shut down the factory.""" @@ -59,114 +585,18 @@ def teardown_method(self): self.factory.destroy(event) event.wait() - def _validate_last_impressions(self, client, *to_validate): - """Validate the last N impressions are present disregarding the order.""" - imp_storage = client._factory._get_storage('impressions') - impressions = imp_storage.pop_many(len(to_validate)) - as_tup_set = set((i.feature_name, i.matching_key, i.treatment) for i in impressions) - assert as_tup_set == set(to_validate) - def test_get_treatment(self): """Test client.get_treatment().""" - client = self.factory.client() - - assert client.get_treatment('user1', 'sample_feature') == 'on' - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - assert client.get_treatment('invalidKey', 'sample_feature') == 'off' - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - assert client.get_treatment('invalidKey', 'invalid_feature') == 'control' - self._validate_last_impressions(client) # No impressions should be present - - # testing a killed feature. No matter what the key, must return default treatment - assert client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - assert client.get_treatment('invalidKey', 'all_feature') == 'on' - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - - # testing WHITELIST matcher - assert client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' - self._validate_last_impressions(client, ('whitelist_feature', 'whitelisted_user', 'on')) - assert client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' - self._validate_last_impressions(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) - - # testing INVALID matcher - assert client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' - self._validate_last_impressions(client) # No impressions should be present - - # testing Dependency matcher - assert client.get_treatment('somekey', 'dependency_test') == 'off' - self._validate_last_impressions(client, ('dependency_test', 'somekey', 'off')) - - # testing boolean matcher - assert client.get_treatment('True', 'boolean_test') == 'on' - self._validate_last_impressions(client, ('boolean_test', 'True', 'on')) - - # testing regex matcher - assert client.get_treatment('abc4', 'regex_test') == 'on' - self._validate_last_impressions(client, ('regex_test', 'abc4', 'on')) + _get_treatment(self.factory) def test_get_treatment_with_config(self): """Test client.get_treatment_with_config().""" - client = self.factory.client() - - result = client.get_treatment_with_config('user1', 'sample_feature') - assert result == ('on', '{"size":15,"test":20}') - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatment_with_config('invalidKey', 'sample_feature') - assert result == ('off', None) - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatment_with_config('invalidKey', 'invalid_feature') - assert result == ('control', None) - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatment_with_config('invalidKey', 'killed_feature') - assert ('defTreatment', '{"size":15,"defTreatment":true}') == result - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatment_with_config('invalidKey', 'all_feature') - assert result == ('on', None) - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + _get_treatment_with_config(self.factory) def test_get_treatments(self): - """Test client.get_treatments().""" + _get_treatments(self.factory) + # testing multiple splitNames client = self.factory.client() - - result = client.get_treatments('user1', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == 'on' - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatments('invalidKey', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == 'off' - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatments('invalidKey', ['invalid_feature']) - assert len(result) == 1 - assert result['invalid_feature'] == 'control' - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatments('invalidKey', ['killed_feature']) - assert len(result) == 1 - assert result['killed_feature'] == 'defTreatment' - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatments('invalidKey', ['all_feature']) - assert len(result) == 1 - assert result['all_feature'] == 'on' - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - - # testing multiple splitNames result = client.get_treatments('invalidKey', [ 'all_feature', 'killed_feature', @@ -178,7 +608,7 @@ def test_get_treatments(self): assert result['killed_feature'] == 'defTreatment' assert result['invalid_feature'] == 'control' assert result['sample_feature'] == 'off' - self._validate_last_impressions( + _validate_last_impressions( client, ('all_feature', 'invalidKey', 'on'), ('killed_feature', 'invalidKey', 'defTreatment'), @@ -187,36 +617,9 @@ def test_get_treatments(self): def test_get_treatments_with_config(self): """Test client.get_treatments_with_config().""" - client = self.factory.client() - - result = client.get_treatments_with_config('user1', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == ('on', '{"size":15,"test":20}') - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatments_with_config('invalidKey', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == ('off', None) - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatments_with_config('invalidKey', ['invalid_feature']) - assert len(result) == 1 - assert result['invalid_feature'] == ('control', None) - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatments_with_config('invalidKey', ['killed_feature']) - assert len(result) == 1 - assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatments_with_config('invalidKey', ['all_feature']) - assert len(result) == 1 - assert result['all_feature'] == ('on', None) - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - + _get_treatments_with_config(self.factory) # testing multiple splitNames + client = self.factory.client() result = client.get_treatments_with_config('invalidKey', [ 'all_feature', 'killed_feature', @@ -228,43 +631,58 @@ def test_get_treatments_with_config(self): assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') assert result['invalid_feature'] == ('control', None) assert result['sample_feature'] == ('off', None) - self._validate_last_impressions( + _validate_last_impressions( client, ('all_feature', 'invalidKey', 'on'), ('killed_feature', 'invalidKey', 'defTreatment'), ('sample_feature', 'invalidKey', 'off'), ) + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) + + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) + + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_track(self): + """Test client.track().""" + _track(self.factory) + def test_manager_methods(self): """Test manager.split/splits.""" - manager = self.factory.manager() - result = manager.split('all_feature') - assert result.name == 'all_feature' - assert result.traffic_type is None - assert result.killed is False - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs == {} - - result = manager.split('killed_feature') - assert result.name == 'killed_feature' - assert result.traffic_type is None - assert result.killed is True - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}' - assert result.configs['off'] == '{"size":15,"test":20}' - - result = manager.split('sample_feature') - assert result.name == 'sample_feature' - assert result.traffic_type is None - assert result.killed is False - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs['on'] == '{"size":15,"test":20}' - - assert len(manager.split_names()) == 7 - assert len(manager.splits()) == 7 + _manager_methods(self.factory) class InMemoryOptimizedIntegrationTests(object): @@ -272,14 +690,18 @@ class InMemoryOptimizedIntegrationTests(object): def setup_method(self): """Prepare storages with test data.""" - split_storage = InMemorySplitStorage() - segment_storage = InMemorySegmentStorage() - + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') with open(split_fn, 'r') as flo: data = json.loads(flo.read()) - for split in data['splits']: - split_storage.put(splits.from_raw(split)) + for split in data['ff']['d']: + split_storage.update([splits.from_raw(split)], [], 0) + + for rbs in data['rbs']['d']: + rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') with open(segment_fn, 'r') as flo: @@ -291,104 +713,186 @@ def setup_method(self): data = json.loads(flo.read()) segment_storage.put(segments.from_raw(data)) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + storages = { 'splits': split_storage, 'segments': segment_storage, - 'impressions': InMemoryImpressionStorage(5000), - 'events': InMemoryEventStorage(5000), + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), } - impmanager = ImpressionsManager(ImpressionsMode.OPTIMIZED, True) - recorder = StandardRecorder(impmanager, storages['events'], storages['impressions']) - self.factory = SplitFactory('some_api_key', storages, True, recorder) # pylint:disable=attribute-defined-outside-init - - def _validate_last_impressions(self, client, *to_validate): - """Validate the last N impressions are present disregarding the order.""" - imp_storage = client._factory._get_storage('impressions') - impressions = imp_storage.pop_many(len(to_validate)) - as_tup_set = set((i.feature_name, i.matching_key, i.treatment) for i in impressions) - assert as_tup_set == set(to_validate) + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, events_queue) + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() def test_get_treatment(self): """Test client.get_treatment().""" + _get_treatment(self.factory) + + def test_get_treatments(self): + """Test client.get_treatments().""" + _get_treatments(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + assert self.factory._storages['impressions']._impressions.qsize() == 0 + + def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + _get_treatments_with_config(self.factory) + # testing multiple splitNames client = self.factory.client() + result = client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + _validate_last_impressions(client,) - assert client.get_treatment('user1', 'sample_feature') == 'on' - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - client.get_treatment('user1', 'sample_feature') - client.get_treatment('user1', 'sample_feature') - client.get_treatment('user1', 'sample_feature') + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) - # Only one impression was added, and popped when validating, the rest were ignored + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + _validate_last_impressions(client, ) assert self.factory._storages['impressions']._impressions.qsize() == 0 - assert client.get_treatment('invalidKey', 'sample_feature') == 'off' - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - assert client.get_treatment('invalidKey', 'invalid_feature') == 'control' - self._validate_last_impressions(client) # No impressions should be present + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) - # testing a killed feature. No matter what the key, must return default treatment - assert client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + _validate_last_impressions(client, ) - # testing ALL matcher - assert client.get_treatment('invalidKey', 'all_feature') == 'on' - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + def test_manager_methods(self): + """Test manager.split/splits.""" + _manager_methods(self.factory) - # testing WHITELIST matcher - assert client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' - self._validate_last_impressions(client, ('whitelist_feature', 'whitelisted_user', 'on')) - assert client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' - self._validate_last_impressions(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) + def test_track(self): + """Test client.track().""" + _track(self.factory) - # testing INVALID matcher - assert client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' - self._validate_last_impressions(client) # No impressions should be present +class InMemoryOldSpecIntegrationTests(object): + """Inmemory storage-based integration tests.""" - # testing Dependency matcher - assert client.get_treatment('somekey', 'dependency_test') == 'off' - self._validate_last_impressions(client, ('dependency_test', 'somekey', 'off')) + def setup_method(self): + """Prepare storages with test data.""" - # testing boolean matcher - assert client.get_treatment('True', 'boolean_test') == 'on' - self._validate_last_impressions(client, ('boolean_test', 'True', 'on')) + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'split_old_spec.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + + split_changes = { + -1: data, + 1457726098069: {"splits": [], "till": 1457726098069, "since": 1457726098069} + } - # testing regex matcher - assert client.get_treatment('abc4', 'regex_test') == 'on' - self._validate_last_impressions(client, ('regex_test', 'abc4', 'on')) + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + segment_employee = json.loads(flo.read()) - def test_get_treatments(self): - """Test client.get_treatments().""" - client = self.factory.client() + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + segment_human = json.loads(flo.read()) + + segment_changes = { + ("employees", -1): segment_employee, + ("employees", 1457474612832): {"name": "employees","added": [],"removed": [],"since": 1457474612832,"till": 1457474612832}, + ("human_beigns", -1): segment_human, + ("human_beigns", 1457102183278): {"name": "employees","added": [],"removed": [],"since": 1457102183278,"till": 1457102183278}, + } - result = client.get_treatments('user1', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == 'on' - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + split_backend_requests = Queue() + self.split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + {'auth_response': {'pushEnabled': False}}, True) + self.split_backend.start() + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % self.split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % self.split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % self.split_backend.port(), + 'config': {'connectTimeout': 10000, + 'streamingEnabled': False, + 'impressionsMode': 'debug', + 'fallbackTreatments': FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')}) + } + } - result = client.get_treatments('invalidKey', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == 'off' - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + self.factory = get_factory('some_apikey', **kwargs) + self.factory.block_until_ready(1) + assert self.factory.ready - result = client.get_treatments('invalidKey', ['invalid_feature']) - assert len(result) == 1 - assert result['invalid_feature'] == 'control' - self._validate_last_impressions(client) + def teardown_method(self): + """Shut down the factory.""" + event = threading.Event() + self.factory.destroy(event) + event.wait() + self.split_backend.stop() + time.sleep(1) - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatments('invalidKey', ['killed_feature']) - assert len(result) == 1 - assert result['killed_feature'] == 'defTreatment' - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + def test_get_treatment(self): + """Test client.get_treatment().""" + _get_treatment(self.factory, True) - # testing ALL matcher - result = client.get_treatments('invalidKey', ['all_feature']) - assert len(result) == 1 - assert result['all_feature'] == 'on' - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + _get_treatment_with_config(self.factory) - # testing multiple splitNames + def test_get_treatments(self): + _get_treatments(self.factory) + # testing multiple splitNames + client = self.factory.client() result = client.get_treatments('invalidKey', [ 'all_feature', 'killed_feature', @@ -400,40 +904,18 @@ def test_get_treatments(self): assert result['killed_feature'] == 'defTreatment' assert result['invalid_feature'] == 'control' assert result['sample_feature'] == 'off' - assert self.factory._storages['impressions']._impressions.qsize() == 0 + _validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off') + ) def test_get_treatments_with_config(self): """Test client.get_treatments_with_config().""" - client = self.factory.client() - - result = client.get_treatments_with_config('user1', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == ('on', '{"size":15,"test":20}') - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatments_with_config('invalidKey', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == ('off', None) - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatments_with_config('invalidKey', ['invalid_feature']) - assert len(result) == 1 - assert result['invalid_feature'] == ('control', None) - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatments_with_config('invalidKey', ['killed_feature']) - assert len(result) == 1 - assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatments_with_config('invalidKey', ['all_feature']) - assert len(result) == 1 - assert result['all_feature'] == ('on', None) - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - + _get_treatments_with_config(self.factory) # testing multiple splitNames + client = self.factory.client() result = client.get_treatments_with_config('invalidKey', [ 'all_feature', 'killed_feature', @@ -441,45 +923,63 @@ def test_get_treatments_with_config(self): 'sample_feature' ]) assert len(result) == 4 - assert result['all_feature'] == ('on', None) assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') assert result['invalid_feature'] == ('control', None) assert result['sample_feature'] == ('off', None) - assert self.factory._storages['impressions']._impressions.qsize() == 0 - - def test_manager_methods(self): - """Test manager.split/splits.""" - manager = self.factory.manager() - result = manager.split('all_feature') - assert result.name == 'all_feature' - assert result.traffic_type is None - assert result.killed is False - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs == {} - - result = manager.split('killed_feature') - assert result.name == 'killed_feature' - assert result.traffic_type is None - assert result.killed is True - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}' - assert result.configs['off'] == '{"size":15,"test":20}' - - result = manager.split('sample_feature') - assert result.name == 'sample_feature' - assert result.traffic_type is None - assert result.killed is False - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs['on'] == '{"size":15,"test":20}' + _validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off'), + ) - assert len(manager.split_names()) == 7 - assert len(manager.splits()) == 7 + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) + + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_track(self): + """Test client.track().""" + _track(self.factory) + def test_manager_methods(self): + """Test manager.split/splits.""" + _manager_methods(self.factory, True) + class RedisIntegrationTests(object): """Redis storage-based integration tests.""" @@ -489,13 +989,20 @@ def setup_method(self): redis_client = build(DEFAULT_CONFIG.copy()) split_storage = RedisSplitStorage(redis_client) segment_storage = RedisSegmentStorage(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorage(redis_client) split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') with open(split_fn, 'r') as flo: data = json.loads(flo.read()) - for split in data['splits']: + for split in data['ff']['d']: redis_client.set(split_storage._get_key(split['name']), json.dumps(split)) - redis_client.set(split_storage._SPLIT_TILL_KEY, data['till']) + if split.get('sets') is not None: + for flag_set in split.get('sets'): + redis_client.sadd(split_storage._get_flag_set_key(flag_set), split['name']) + redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, data['ff']['t']) + + for rbs in data['rbs']['d']: + redis_client.set(rb_segment_storage._get_key(rbs['name']), json.dumps(rbs)) segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') with open(segment_fn, 'r') as flo: @@ -509,132 +1016,45 @@ def setup_method(self): redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + telemetry_redis_storage = RedisTelemetryStorage(redis_client, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storages = { 'splits': split_storage, 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, 'impressions': RedisImpressionsStorage(redis_client, metadata), 'events': RedisEventsStorage(redis_client, metadata), } - impmanager = ImpressionsManager(ImpressionsMode.DEBUG, False) - recorder = PipelinedRecorder(redis_client.pipeline, impmanager, - storages['events'], storages['impressions']) - self.factory = SplitFactory('some_api_key', storages, True, recorder) # pylint:disable=attribute-defined-outside-init - - def _validate_last_impressions(self, client, *to_validate): - """Validate the last N impressions are present disregarding the order.""" - imp_storage = client._factory._get_storage('impressions') - redis_client = imp_storage._redis - impressions_raw = [ - json.loads(redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY)) - for _ in to_validate - ] - as_tup_set = set( - (i['i']['f'], i['i']['k'], i['i']['t']) - for i in impressions_raw - ) - - assert as_tup_set == set(to_validate) + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorder(redis_client.pipeline, impmanager, storages['events'], + storages['impressions'], telemetry_redis_storage, imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_queue = queue.Queue() + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init def test_get_treatment(self): """Test client.get_treatment().""" - client = self.factory.client() - - assert client.get_treatment('user1', 'sample_feature') == 'on' - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - assert client.get_treatment('invalidKey', 'sample_feature') == 'off' - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - assert client.get_treatment('invalidKey', 'invalid_feature') == 'control' - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - assert client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - assert client.get_treatment('invalidKey', 'all_feature') == 'on' - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - - # testing WHITELIST matcher - assert client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' - self._validate_last_impressions(client, ('whitelist_feature', 'whitelisted_user', 'on')) - assert client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' - self._validate_last_impressions(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) - - # testing INVALID matcher - assert client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' - self._validate_last_impressions(client) - - # testing Dependency matcher - assert client.get_treatment('somekey', 'dependency_test') == 'off' - self._validate_last_impressions(client, ('dependency_test', 'somekey', 'off')) - - # testing boolean matcher - assert client.get_treatment('True', 'boolean_test') == 'on' - self._validate_last_impressions(client, ('boolean_test', 'True', 'on')) - - # testing regex matcher - assert client.get_treatment('abc4', 'regex_test') == 'on' - self._validate_last_impressions(client, ('regex_test', 'abc4', 'on')) + _get_treatment(self.factory) def test_get_treatment_with_config(self): """Test client.get_treatment_with_config().""" - client = self.factory.client() - - result = client.get_treatment_with_config('user1', 'sample_feature') - assert result == ('on', '{"size":15,"test":20}') - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatment_with_config('invalidKey', 'sample_feature') - assert result == ('off', None) - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatment_with_config('invalidKey', 'invalid_feature') - assert result == ('control', None) - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatment_with_config('invalidKey', 'killed_feature') - assert ('defTreatment', '{"size":15,"defTreatment":true}') == result - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatment_with_config('invalidKey', 'all_feature') - assert result == ('on', None) - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + _get_treatment_with_config(self.factory) def test_get_treatments(self): """Test client.get_treatments().""" + _get_treatments(self.factory) client = self.factory.client() - - result = client.get_treatments('user1', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == 'on' - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatments('invalidKey', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == 'off' - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatments('invalidKey', ['invalid_feature']) - assert len(result) == 1 - assert result['invalid_feature'] == 'control' - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatments('invalidKey', ['killed_feature']) - assert len(result) == 1 - assert result['killed_feature'] == 'defTreatment' - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatments('invalidKey', ['all_feature']) - assert len(result) == 1 - assert result['all_feature'] == 'on' - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - # testing multiple splitNames result = client.get_treatments('invalidKey', [ 'all_feature', @@ -647,44 +1067,21 @@ def test_get_treatments(self): assert result['killed_feature'] == 'defTreatment' assert result['invalid_feature'] == 'control' assert result['sample_feature'] == 'off' - self._validate_last_impressions( + _validate_last_impressions( client, ('all_feature', 'invalidKey', 'on'), ('killed_feature', 'invalidKey', 'defTreatment'), ('sample_feature', 'invalidKey', 'off') ) + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + _get_treatment_with_config(self.factory) + def test_get_treatments_with_config(self): """Test client.get_treatments_with_config().""" + _get_treatments_with_config(self.factory) client = self.factory.client() - - result = client.get_treatments_with_config('user1', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == ('on', '{"size":15,"test":20}') - self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) - - result = client.get_treatments_with_config('invalidKey', ['sample_feature']) - assert len(result) == 1 - assert result['sample_feature'] == ('off', None) - self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) - - result = client.get_treatments_with_config('invalidKey', ['invalid_feature']) - assert len(result) == 1 - assert result['invalid_feature'] == ('control', None) - self._validate_last_impressions(client) - - # testing a killed feature. No matter what the key, must return default treatment - result = client.get_treatments_with_config('invalidKey', ['killed_feature']) - assert len(result) == 1 - assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') - self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) - - # testing ALL matcher - result = client.get_treatments_with_config('invalidKey', ['all_feature']) - assert len(result) == 1 - assert result['all_feature'] == ('on', None) - self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) - # testing multiple splitNames result = client.get_treatments_with_config('invalidKey', [ 'all_feature', @@ -697,43 +1094,58 @@ def test_get_treatments_with_config(self): assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') assert result['invalid_feature'] == ('control', None) assert result['sample_feature'] == ('off', None) - self._validate_last_impressions( + _validate_last_impressions( client, ('all_feature', 'invalidKey', 'on'), ('killed_feature', 'invalidKey', 'defTreatment'), ('sample_feature', 'invalidKey', 'off'), ) + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) + + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) + + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_track(self): + """Test client.track().""" + _track(self.factory) + def test_manager_methods(self): """Test manager.split/splits.""" - manager = self.factory.manager() - result = manager.split('all_feature') - assert result.name == 'all_feature' - assert result.traffic_type is None - assert result.killed is False - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs == {} - - result = manager.split('killed_feature') - assert result.name == 'killed_feature' - assert result.traffic_type is None - assert result.killed is True - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}' - assert result.configs['off'] == '{"size":15,"test":20}' - - result = manager.split('sample_feature') - assert result.name == 'sample_feature' - assert result.traffic_type is None - assert result.killed is False - assert len(result.treatments) == 2 - assert result.change_number == 123 - assert result.configs['on'] == '{"size":15,"test":20}' - - assert len(manager.split_names()) == 7 - assert len(manager.splits()) == 7 + _manager_methods(self.factory) def teardown_method(self): """Clear redis cache.""" @@ -749,14 +1161,20 @@ def teardown_method(self): "SPLITIO.split.regex_test", "SPLITIO.segment.human_beigns.till", "SPLITIO.split.boolean_test", - "SPLITIO.split.dependency_test" + "SPLITIO.split.dependency_test", + "SPLITIO.split.set.set1", + "SPLITIO.split.set.set2", + "SPLITIO.split.set.set3", + "SPLITIO.split.set.set4", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" ] redis_client = RedisAdapter(StrictRedis()) for key in keys_to_delete: redis_client.delete(key) - class RedisWithCacheIntegrationTests(RedisIntegrationTests): """Run the same tests as RedisIntegratioTests but with LRU/Expirable cache overlay.""" @@ -766,13 +1184,17 @@ def setup_method(self): redis_client = build(DEFAULT_CONFIG.copy()) split_storage = RedisSplitStorage(redis_client, True) segment_storage = RedisSegmentStorage(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorage(redis_client) split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') with open(split_fn, 'r') as flo: data = json.loads(flo.read()) - for split in data['splits']: + for split in data['ff']['d']: redis_client.set(split_storage._get_key(split['name']), json.dumps(split)) - redis_client.set(split_storage._SPLIT_TILL_KEY, data['till']) + redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, data['ff']['t']) + + for rbs in data['rbs']['d']: + redis_client.set(rb_segment_storage._get_key(rbs['name']), json.dumps(rbs)) segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') with open(segment_fn, 'r') as flo: @@ -786,21 +1208,189 @@ def setup_method(self): redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + telemetry_redis_storage = RedisTelemetryStorage(redis_client, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storages = { 'splits': split_storage, 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, 'impressions': RedisImpressionsStorage(redis_client, metadata), 'events': RedisEventsStorage(redis_client, metadata), } - impmanager = ImpressionsManager(storages['impressions'].put, ImpressionsMode.DEBUG) + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener recorder = PipelinedRecorder(redis_client.pipeline, impmanager, - storages['events'], storages['impressions']) - self.factory = SplitFactory('some_api_key', storages, True, recorder) # pylint:disable=attribute-defined-outside-init - - + storages['events'], storages['impressions'], telemetry_redis_storage, imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_queue = queue.Queue() + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + class LocalhostIntegrationTests(object): # pylint: disable=too-few-public-methods """Client & Manager integration tests.""" + def test_localhost_json_e2e(self): + """Instantiate a client with a JSON file and issue get_treatment() calls.""" + self._update_temp_file(splits_json['splitChange2_1']) + filename = os.path.join(os.path.dirname(__file__), 'files', 'split_changes_temp.json') + self.factory = get_factory('localhost', config={'splitFile': filename}) + self.factory.block_until_ready(1) + client = self.factory.client() + + # Tests 2 + assert self.factory.manager().split_names() == ["SPLIT_1"] + assert client.get_treatment("key", "SPLIT_1") == 'off' + + # Tests 1 + self.factory._storages['splits'].update([], ['SPLIT_1'], -1) + self._update_temp_file(splits_json['splitChange1_1']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'off' + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange1_2']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'off' + assert client.get_treatment("key", "SPLIT_2", None) == 'off' + + self._update_temp_file(splits_json['splitChange1_3']) + self._synchronize_now() + + assert self.factory.manager().split_names() == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'control' + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + # Tests 3 + self.factory._storages['splits'].update([], ['SPLIT_1'], -1) + self._update_temp_file(splits_json['splitChange3_1']) + self._synchronize_now() + + assert self.factory.manager().split_names() == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange3_2']) + self._synchronize_now() + + assert self.factory.manager().split_names() == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_2", None) == 'off' + + # Tests 4 + self.factory._storages['splits'].update([], ['SPLIT_2'], -1) + self._update_temp_file(splits_json['splitChange4_1']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'off' + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange4_2']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'off' + assert client.get_treatment("key", "SPLIT_2", None) == 'off' + + self._update_temp_file(splits_json['splitChange4_3']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'control' + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + # Tests 5 + self.factory._storages['splits'].update([], ['SPLIT_1', 'SPLIT_2'], -1) + self._update_temp_file(splits_json['splitChange5_1']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange5_2']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + # Tests 6 + self.factory._storages['splits'].update([], ['SPLIT_2'], -1) + self._update_temp_file(splits_json['splitChange6_1']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'off' + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange6_2']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'off' + assert client.get_treatment("key", "SPLIT_2", None) == 'off' + + self._update_temp_file(splits_json['splitChange6_3']) + self._synchronize_now() + + assert sorted(self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert client.get_treatment("key", "SPLIT_1", None) == 'control' + assert client.get_treatment("key", "SPLIT_2", None) == 'on' + + # rule based segment test + self._update_temp_file(splits_json['splitChange7_1']) + self._synchronize_now() + assert client.get_treatment('bilal@split.io', 'rbs_feature_flag', {'email': 'bilal@split.io'}) == 'on' + assert client.get_treatment('mauro@split.io', 'rbs_feature_flag', {'email': 'mauro@split.io'}) == 'off' + + def _update_temp_file(self, json_body): + f = open(os.path.join(os.path.dirname(__file__), 'files','split_changes_temp.json'), 'w') + f.write(json.dumps(json_body)) + f.close() + + def _synchronize_now(self): + filename = os.path.join(os.path.dirname(__file__), 'files', 'split_changes_temp.json') + self.factory._sync_manager._synchronizer._split_synchronizers._feature_flag_sync._filename = filename + self.factory._sync_manager._synchronizer._split_synchronizers._feature_flag_sync.synchronize_splits() + + def test_incorrect_file_e2e(self): + """Test initialize factory with a incorrect file name.""" + # TODO: secontion below is removed when legacu use BUR + # legacy and yaml + exception_raised = False + factory = None + try: + factory = get_factory('localhost', config={'splitFile': 'filename'}) + except Exception as e: + exception_raised = True + + assert(exception_raised) + + # json using BUR + factory = get_factory('localhost', config={'splitFile': 'filename.json'}) + exception_raised = False + try: + factory.block_until_ready(1) + except Exception as e: + exception_raised = True + + assert(exception_raised) + + event = threading.Event() + factory.destroy(event) + event.wait() + def test_localhost_e2e(self): """Instantiate a client with a YAML file and issue get_treatment() calls.""" filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') @@ -830,7 +1420,3941 @@ def test_localhost_e2e(self): factory.destroy(event) event.wait() - # hack to increase isolation and prevent conflicts with other tests - thread = factory._sync_manager._synchronizer._split_tasks.split_task._task._thread - if thread is not None and thread.is_alive(): - thread.join() + def test_fallback_treatments(self): + """Instantiate a client with a JSON file and issue get_treatment() calls.""" + self._update_temp_file(splits_json['splitChange2_1']) + filename = os.path.join(os.path.dirname(__file__), 'files', 'split_changes_temp.json') + factory = get_factory('localhost', + config={ + 'splitFile': filename, + 'fallbackTreatments': FallbackTreatmentsConfiguration("on-global", {'fallback_feature': "on-local"}) + } + ) + factory.block_until_ready(1) + client = factory.client() + + assert client.get_treatment("key", "feature") == "on-global" + assert client.get_treatment("key", "fallback_feature") == "on-local" + + event = threading.Event() + factory.destroy(event) + event.wait() + + +class PluggableIntegrationTests(object): + """Pluggable storage-based integration tests.""" + + def setup_method(self): + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + self.pluggable_storage_adapter = StorageMockAdapter() + split_storage = PluggableSplitStorage(self.pluggable_storage_adapter) + segment_storage = PluggableSegmentStorage(self.pluggable_storage_adapter) + rb_segment_storage = PluggableRuleBasedSegmentsStorage(self.pluggable_storage_adapter) + + telemetry_pluggable_storage = PluggableTelemetryStorage(self.pluggable_storage_adapter, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_pluggable_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': PluggableImpressionsStorage(self.pluggable_storage_adapter, metadata), + 'events': PluggableEventsStorage(self.pluggable_storage_adapter, metadata), + 'telemetry': telemetry_pluggable_storage + } + + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], + storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_queue = queue.Queue() + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + RedisManager(PluggableSynchronizer()), + sdk_ready_flag=None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + # Adding data to storage + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + self.pluggable_storage_adapter.set(split_storage._prefix.format(feature_flag_name=split['name']), split) + if split.get('sets') is not None: + for flag_set in split.get('sets'): + self.pluggable_storage_adapter.push_items(split_storage._flag_set_prefix.format(flag_set=flag_set), split['name']) + self.pluggable_storage_adapter.set(split_storage._feature_flag_till_prefix, data['ff']['t']) + + for rbs in data['rbs']['d']: + self.pluggable_storage_adapter.set(rb_segment_storage._prefix.format(segment_name=rbs['name']), rbs) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + def test_get_treatment(self): + """Test client.get_treatment().""" + _get_treatment(self.factory) + + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + _get_treatment_with_config(self.factory) + + def test_get_treatments(self): + """Test client.get_treatments().""" + _get_treatments(self.factory) + client = self.factory.client() + # testing multiple splitNames + result = client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + _validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off') + ) + + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + _get_treatment_with_config(self.factory) + + def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + _get_treatments_with_config(self.factory) + client = self.factory.client() + # testing multiple splitNames + result = client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + _validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off'), + ) + + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) + + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) + + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + _validate_last_impressions(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + def test_track(self): + """Test client.track().""" + _track(self.factory) + + def test_manager_methods(self): + """Test manager.split/splits.""" + _manager_methods(self.factory) + + def teardown_method(self): + """Clear pluggable cache.""" + keys_to_delete = [ + "SPLITIO.segment.human_beigns", + "SPLITIO.segment.employees.till", + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.killed_feature", + "SPLITIO.split.all_feature", + "SPLITIO.split.whitelist_feature", + "SPLITIO.segment.employees", + "SPLITIO.split.regex_test", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.split.set.set1", + "SPLITIO.split.set.set2", + "SPLITIO.split.set.set3", + "SPLITIO.split.set.set4", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" + ] + for key in keys_to_delete: + self.pluggable_storage_adapter.delete(key) + +class PluggableOptimizedIntegrationTests(object): + """Pluggable storage-based integration tests.""" + + def setup_method(self): + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + self.pluggable_storage_adapter = StorageMockAdapter() + split_storage = PluggableSplitStorage(self.pluggable_storage_adapter) + segment_storage = PluggableSegmentStorage(self.pluggable_storage_adapter) + rb_segment_storage = PluggableRuleBasedSegmentsStorage(self.pluggable_storage_adapter) + + telemetry_pluggable_storage = PluggableTelemetryStorage(self.pluggable_storage_adapter, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_pluggable_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': PluggableImpressionsStorage(self.pluggable_storage_adapter, metadata), + 'events': PluggableEventsStorage(self.pluggable_storage_adapter, metadata), + 'telemetry': telemetry_pluggable_storage + } + + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], + storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_queue = queue.Queue() + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + RedisManager(PluggableSynchronizer()), + sdk_ready_flag=None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + # Adding data to storage + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + if split.get('sets') is not None: + for flag_set in split.get('sets'): + self.pluggable_storage_adapter.push_items(split_storage._flag_set_prefix.format(flag_set=flag_set), split['name']) + self.pluggable_storage_adapter.set(split_storage._prefix.format(feature_flag_name=split['name']), split) + self.pluggable_storage_adapter.set(split_storage._feature_flag_till_prefix, data['ff']['t']) + + for rbs in data['rbs']['d']: + self.pluggable_storage_adapter.set(rb_segment_storage._prefix.format(segment_name=rbs['name']), rbs) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + def test_get_treatment(self): + """Test client.get_treatment().""" + _get_treatment(self.factory) + client = self.factory.client() + + assert client.get_treatment('user1', 'sample_feature') == 'on' + client.get_treatment('user1', 'sample_feature') + client.get_treatment('user1', 'sample_feature') + client.get_treatment('user1', 'sample_feature') + assert len(self.pluggable_storage_adapter._keys['SPLITIO.impressions']) == 1 + + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + _get_treatment_with_config(self.factory) + + def test_get_treatments(self): + """Test client.get_treatments().""" + _get_treatments(self.factory) + + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + _get_treatment_with_config(self.factory) + + def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + _get_treatments_with_config(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + _validate_last_impressions(client,) + + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) + + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + _validate_last_impressions(client, ) + + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) + + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + client = self.factory.client() + result = client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + _validate_last_impressions(client, ) + + def test_track(self): + """Test client.track().""" + _track(self.factory) + + def test_manager_methods(self): + """Test manager.split/splits.""" + _manager_methods(self.factory) + + def teardown_method(self): + """Clear pluggable cache.""" + keys_to_delete = [ + "SPLITIO.segment.human_beigns", + "SPLITIO.segment.employees.till", + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.killed_feature", + "SPLITIO.split.all_feature", + "SPLITIO.split.whitelist_feature", + "SPLITIO.segment.employees", + "SPLITIO.split.regex_test", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.split.set.set1", + "SPLITIO.split.set.set2", + "SPLITIO.split.set.set3", + "SPLITIO.split.set.set4", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" + ] + for key in keys_to_delete: + self.pluggable_storage_adapter.delete(key) + +class PluggableNoneIntegrationTests(object): + """Pluggable storage-based integration tests.""" + + def setup_method(self): + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + self.pluggable_storage_adapter = StorageMockAdapter() + split_storage = PluggableSplitStorage(self.pluggable_storage_adapter) + segment_storage = PluggableSegmentStorage(self.pluggable_storage_adapter) + rb_segment_storage = PluggableRuleBasedSegmentsStorage(self.pluggable_storage_adapter) + telemetry_pluggable_storage = PluggableTelemetryStorage(self.pluggable_storage_adapter, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_pluggable_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': PluggableImpressionsStorage(self.pluggable_storage_adapter, metadata), + 'events': PluggableEventsStorage(self.pluggable_storage_adapter, metadata), + 'telemetry': telemetry_pluggable_storage + } + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTracker() + unique_keys_synchronizer, clear_filter_sync, self.unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes('PLUGGABLE', ImpressionsMode.NONE, self.pluggable_storage_adapter, imp_counter, unique_keys_tracker) + impmanager = ImpressionsManager(imp_strategy, none_strategy, telemetry_runtime_producer) # no listener + + recorder = StandardRecorder(impmanager, storages['events'], + storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + + synchronizers = SplitSynchronizers(None, None, None, None, + impressions_count_sync, + None, + unique_keys_synchronizer, + clear_filter_sync + ) + + tasks = SplitTasks(None, None, None, None, + impressions_count_task, + None, + self.unique_keys_task, + clear_filter_task + ) + + synchronizer = RedisSynchronizer(synchronizers, tasks) + + manager = RedisManager(synchronizer) + manager.start() + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_queue = queue.Queue() + self.factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + manager, + sdk_ready_flag=None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + # Adding data to storage + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + if split.get('sets') is not None: + for flag_set in split.get('sets'): + self.pluggable_storage_adapter.push_items(split_storage._flag_set_prefix.format(flag_set=flag_set), split['name']) + self.pluggable_storage_adapter.set(split_storage._prefix.format(feature_flag_name=split['name']), split) + self.pluggable_storage_adapter.set(split_storage._feature_flag_till_prefix, data['ff']['t']) + + for rbs in data['rbs']['d']: + self.pluggable_storage_adapter.set(rb_segment_storage._prefix.format(segment_name=rbs['name']), rbs) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + self.client = self.factory.client() + + def test_get_treatment(self): + """Test client.get_treatment().""" + _get_treatment(self.factory) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_get_treatments(self): + """Test client.get_treatments().""" + _get_treatments(self.factory) + result = self.client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + _get_treatments_with_config(self.factory) + result = self.client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + _get_treatments_by_flag_set(self.factory) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + _get_treatments_by_flag_sets(self.factory) + result = self.client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + _get_treatments_with_config_by_flag_set(self.factory) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + _get_treatments_with_config_by_flag_sets(self.factory) + result = self.client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + + def test_track(self): + """Test client.track().""" + _track(self.factory) + + def test_mtk(self): + self.client.get_treatment('user1', 'sample_feature') + self.client.get_treatment('invalidKey', 'sample_feature') + self.client.get_treatment('invalidKey2', 'sample_feature') + self.client.get_treatment('user22', 'invalidFeature') + self.unique_keys_task._task.force_execution() + time.sleep(1) + + assert(json.loads(self.pluggable_storage_adapter._keys['SPLITIO.uniquekeys'][0])["f"] =="sample_feature") + assert(json.loads(self.pluggable_storage_adapter._keys['SPLITIO.uniquekeys'][0])["ks"].sort() == + ["invalidKey2", "invalidKey", "user1"].sort()) + event = threading.Event() + self.factory.destroy(event) + event.wait() + +class InMemoryImpressionsToggleIntegrationTests(object): + """InMemory storage-based impressions toggle integration tests.""" + + def test_optimized(self): + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + + split_storage.update([splits.from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': InMemoryRuleBasedSegmentStorage(events_queue), + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, None, UniqueKeysTracker(), ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + except: + pass + + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'SPLIT_1') == 'off' + assert client.get_treatment('user1', 'SPLIT_2') == 'on' + assert client.get_treatment('user1', 'SPLIT_3') == 'on' + imp_storage = client._factory._get_storage('impressions') + impressions = imp_storage.pop_many(10) + assert len(impressions) == 2 + assert impressions[0].feature_name == 'SPLIT_1' + assert impressions[1].feature_name == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user1'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + assert client.get_treatment('user1', 'incorrect_feature') == 'on-global' + assert client.get_treatment('user1', 'fallback_feature') == 'on-local' + + def test_debug(self): + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + + split_storage.update([splits.from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': InMemoryRuleBasedSegmentStorage(events_queue), + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, None, UniqueKeysTracker(), ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, events_queue) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass + + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'SPLIT_1') == 'off' + assert client.get_treatment('user1', 'SPLIT_2') == 'on' + assert client.get_treatment('user1', 'SPLIT_3') == 'on' + imp_storage = client._factory._get_storage('impressions') + impressions = imp_storage.pop_many(10) + assert len(impressions) == 2 + assert impressions[0].feature_name == 'SPLIT_1' + assert impressions[1].feature_name == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user1'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + assert client.get_treatment('user1', 'incorrect_feature') == 'on-global' + assert client.get_treatment('user1', 'fallback_feature') == 'on-local' + + def test_none(self): + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + + split_storage.update([splits.from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': InMemoryRuleBasedSegmentStorage(events_queue), + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyNoneMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, None, UniqueKeysTracker(), ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, events_queue) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_queue, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass + + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'SPLIT_1') == 'off' + assert client.get_treatment('user1', 'SPLIT_2') == 'on' + assert client.get_treatment('user1', 'SPLIT_3') == 'on' + imp_storage = client._factory._get_storage('impressions') + impressions = imp_storage.pop_many(10) + assert len(impressions) == 0 + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_1': {'user1'}, 'SPLIT_2': {'user1'}, 'SPLIT_3': {'user1'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 3 + assert imps_count[0].feature == 'SPLIT_1' + assert imps_count[0].count == 1 + assert imps_count[1].feature == 'SPLIT_2' + assert imps_count[1].count == 1 + assert imps_count[2].feature == 'SPLIT_3' + assert imps_count[2].count == 1 + assert client.get_treatment('user1', 'incorrect_feature') == 'on-global' + assert client.get_treatment('user1', 'fallback_feature') == 'on-local' + +class RedisImpressionsToggleIntegrationTests(object): + """Run impression toggle tests for Redis.""" + + def test_optimized(self): + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = build(DEFAULT_CONFIG.copy()) + split_storage = RedisSplitStorage(redis_client, True) + segment_storage = RedisSegmentStorage(redis_client) + + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][0]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][0])) + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][1]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][1])) + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][2]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][2])) + redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, -1) + + telemetry_redis_storage = RedisTelemetryStorage(redis_client, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': RedisRuleBasedSegmentsStorage(redis_client), + 'impressions': RedisImpressionsStorage(redis_client, metadata), + 'events': RedisEventsStorage(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorder(redis_client.pipeline, impmanager, + storages['events'], storages['impressions'], telemetry_redis_storage, unique_keys_tracker=UniqueKeysTracker(), imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + events_queue = queue.Queue() + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'SPLIT_1') == 'off' + assert client.get_treatment('user2', 'SPLIT_2') == 'on' + assert client.get_treatment('user3', 'SPLIT_3') == 'on' + time.sleep(0.2) + + imp_storage = factory._storages['impressions'] + impressions = [] + while True: + impression = redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY) + if impression is None: + break + impressions.append(json.loads(impression)) + + assert len(impressions) == 2 + assert impressions[0]['i']['f'] == 'SPLIT_1' + assert impressions[1]['i']['f'] == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user3'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + assert client.get_treatment('user1', 'incorrect_feature') == 'on-global' + assert client.get_treatment('user1', 'fallback_feature') == 'on-local' + self.clear_cache() + client.destroy() + + def test_debug(self): + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = build(DEFAULT_CONFIG.copy()) + split_storage = RedisSplitStorage(redis_client, True) + segment_storage = RedisSegmentStorage(redis_client) + + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][0]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][0])) + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][1]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][1])) + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][2]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][2])) + redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, -1) + + telemetry_redis_storage = RedisTelemetryStorage(redis_client, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': RedisRuleBasedSegmentsStorage(redis_client), + 'impressions': RedisImpressionsStorage(redis_client, metadata), + 'events': RedisEventsStorage(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorder(redis_client.pipeline, impmanager, + storages['events'], storages['impressions'], telemetry_redis_storage, unique_keys_tracker=UniqueKeysTracker(), imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + queue.Queue(), + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'SPLIT_1') == 'off' + assert client.get_treatment('user2', 'SPLIT_2') == 'on' + assert client.get_treatment('user3', 'SPLIT_3') == 'on' + time.sleep(0.2) + + imp_storage = factory._storages['impressions'] + impressions = [] + while True: + impression = redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY) + if impression is None: + break + impressions.append(json.loads(impression)) + + assert len(impressions) == 2 + assert impressions[0]['i']['f'] == 'SPLIT_1' + assert impressions[1]['i']['f'] == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user3'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + assert client.get_treatment('user1', 'incorrect_feature') == 'on-global' + assert client.get_treatment('user1', 'fallback_feature') == 'on-local' + self.clear_cache() + client.destroy() + + def test_none(self): + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = build(DEFAULT_CONFIG.copy()) + split_storage = RedisSplitStorage(redis_client, True) + segment_storage = RedisSegmentStorage(redis_client) + + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][0]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][0])) + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][1]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][1])) + redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][2]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][2])) + redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, -1) + + telemetry_redis_storage = RedisTelemetryStorage(redis_client, metadata) + telemetry_producer = TelemetryStorageProducer(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': RedisRuleBasedSegmentsStorage(redis_client), + 'impressions': RedisImpressionsStorage(redis_client, metadata), + 'events': RedisEventsStorage(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyNoneMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorder(redis_client.pipeline, impmanager, + storages['events'], storages['impressions'], telemetry_redis_storage, unique_keys_tracker=UniqueKeysTracker(), imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + queue.Queue(), + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(FallbackTreatment("on-global"), {'fallback_feature': FallbackTreatment("on-local", '{"prop":"val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + try: + client = factory.client() + except: + pass + + assert client.get_treatment('user1', 'SPLIT_1') == 'off' + assert client.get_treatment('user2', 'SPLIT_2') == 'on' + assert client.get_treatment('user3', 'SPLIT_3') == 'on' + time.sleep(0.2) + + imp_storage = factory._storages['impressions'] + impressions = [] + while True: + impression = redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY) + if impression is None: + break + impressions.append(json.loads(impression)) + + assert len(impressions) == 0 + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_1': {'user1'}, 'SPLIT_2': {'user2'}, 'SPLIT_3': {'user3'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 3 + assert imps_count[0].feature == 'SPLIT_1' + assert imps_count[0].count == 1 + assert imps_count[1].feature == 'SPLIT_2' + assert imps_count[1].count == 1 + assert imps_count[2].feature == 'SPLIT_3' + assert imps_count[2].count == 1 + assert client.get_treatment('user1', 'incorrect_feature') == 'on-global' + assert client.get_treatment('user1', 'fallback_feature') == 'on-local' + self.clear_cache() + client.destroy() + + def clear_cache(self): + """Clear redis cache.""" + keys_to_delete = [ + "SPLITIO.split.SPLIT_3", + "SPLITIO.splits.till", + "SPLITIO.split.SPLIT_2", + "SPLITIO.split.SPLIT_1", + "SPLITIO.telemetry.latencies" + ] + + redis_client = RedisAdapter(StrictRedis()) + for key in keys_to_delete: + redis_client.delete(key) + +class InMemoryEventsNotificationTests(object): + """Inmemory storage-based events notification tests.""" + + ready_flag = False + timeout_flag = False + + def test_sdk_ready(self): + """Prepare storages with test data.""" + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + split_storage.update([splits.from_raw(split)], [], 0) + + for rbs in data['rbs']['d']: + rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + segment_storage.put(segments.from_raw(data)) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, events_queue) + + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass + + client = factory.client() + client.on(SdkEvent.SDK_READY, self._ready_callback) + factory.block_until_ready(5) + assert self.ready_flag + + """Shut down the factory.""" + event = threading.Event() + factory.destroy(event) + event.wait() + + def test_sdk_ready_fire_later(self): + """Prepare storages with test data.""" + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + split_storage.update([splits.from_raw(split)], [], 0) + + for rbs in data['rbs']['d']: + rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + segment_storage.put(segments.from_raw(data)) + + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorage(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorage(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorder(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + events_manager = EventsManager(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTask(events_manager.notify_internal_event, events_queue) + + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactory('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass + + client = factory.client() + factory.block_until_ready(5) + + assert client.get_treatment('user1', 'sample_feature', evaluation_options=EvaluationOptions({"prop": "value"})) == 'on' + + self.ready_flag = False + client.on(SdkEvent.SDK_READY, self._ready_callback) + assert self.ready_flag + + """Shut down the factory.""" + event = threading.Event() + factory.destroy(event) + event.wait() + + def _ready_callback(self, metadata): + self.ready_flag = True + + def _timeout_callback(self, metadata): + self.timeout_flag = True + +class InMemoryEventsNotificationAsyncTests(object): + """Inmemory storage-based events notification tests.""" + + ready_flag = False + timeout_flag = False + + @pytest.mark.asyncio + async def test_sdk_ready(self): + """Prepare storages with test data.""" + events_queue = asyncio.Queue() + split_storage = InMemorySplitStorageAsync(events_queue) + segment_storage = InMemorySegmentStorageAsync(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(events_queue) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await split_storage.update([splits.from_raw(split)], [], 0) + + for rbs in data['rbs']['d']: + await rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTaskAsync(events_manager.notify_internal_event, events_queue) + + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass + + client = factory.client() + await client.on(SdkEvent.SDK_READY, self._ready_callback) + await factory.block_until_ready(5) + assert self.ready_flag + + """Shut down the factory.""" + await internal_events_task.stop() + await factory.destroy() + + @pytest.mark.asyncio + async def test_sdk_ready_fire_later(self): + """Prepare storages with test data.""" + events_queue = asyncio.Queue() + split_storage = InMemorySplitStorageAsync(events_queue) + segment_storage = InMemorySegmentStorageAsync(events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(events_queue) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await split_storage.update([splits.from_raw(split)], [], 0) + + for rbs in data['rbs']['d']: + await rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + internal_events_task = EventsTaskAsync(events_manager.notify_internal_event, events_queue) + + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + internal_events_task.start() + except: + pass + + client = factory.client() + await factory.block_until_ready(5) + await client.on(SdkEvent.SDK_READY, self._ready_callback) + + """Shut down the factory.""" + await internal_events_task.stop() + await factory.destroy() + + async def _ready_callback(self, metadata): + self.ready_flag = True + + async def _timeout_callback(self, metadata): + self.timeout_flag = True + +class InMemoryIntegrationAsyncTests(object): + """Inmemory storage-based integration tests.""" + + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await split_storage.update([splits.from_raw(split)], [], -1) + + for rbs in data['rbs']['d']: + await rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + except: + pass + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(self.factory).ready = ready_property + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await _get_treatment_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + await _get_treatment_with_config_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatments(self): + await _get_treatments_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off') + ) + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await _get_treatments_with_config_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off'), + ) + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await _get_treatments_by_flag_set_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await _get_treatments_with_config_by_flag_set_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await _track_async(self.factory) + + @pytest.mark.asyncio + async def test_manager_methods(self): + """Test manager.split/splits.""" + await _manager_methods_async(self.factory) + await self.factory.destroy() + +class InMemoryOptimizedIntegrationAsyncTests(object): + """Inmemory storage-based integration tests.""" + + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + rb_segment_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await split_storage.update([splits.from_raw(split)], [], -1) + + for rbs in data['rbs']['d']: + await rb_segment_storage.update([rule_based_segments.from_raw(rbs)], [], 0) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await segment_storage.put(segments.from_raw(data)) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, + imp_counter = ImpressionsCounter()) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + except: + pass + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(self.factory).ready = ready_property + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await _get_treatment_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatments(self): + """Test client.get_treatments().""" + await _get_treatments_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + assert self.factory._storages['impressions']._impressions.qsize() == 0 + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await _get_treatments_with_config_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + await _validate_last_impressions_async(client,) + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await _get_treatments_by_flag_set_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + await _validate_last_impressions_async(client, ) + assert self.factory._storages['impressions']._impressions.qsize() == 0 + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await _get_treatments_with_config_by_flag_set_async(self.factory) + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + await _validate_last_impressions_async(client, ) + + @pytest.mark.asyncio + async def test_manager_methods(self): + """Test manager.split/splits.""" + await _manager_methods_async(self.factory) + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await _track_async(self.factory) + await self.factory.destroy() + +class InMemoryOldSpecIntegrationAsyncTests(object): + """Inmemory storage-based integration tests.""" + + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'split_old_spec.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + + split_changes = { + -1: data, + 1457726098069: {"splits": [], "till": 1457726098069, "since": 1457726098069} + } + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + segment_employee = json.loads(flo.read()) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + segment_human = json.loads(flo.read()) + + segment_changes = { + ("employees", -1): segment_employee, + ("employees", 1457474612832): {"name": "employees","added": [],"removed": [],"since": 1457474612832,"till": 1457474612832}, + ("human_beigns", -1): segment_human, + ("human_beigns", 1457102183278): {"name": "employees","added": [],"removed": [],"since": 1457102183278,"till": 1457102183278}, + } + + split_backend_requests = Queue() + self.split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + {'auth_response': {'pushEnabled': False}}, True) + self.split_backend.start() + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % self.split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % self.split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % self.split_backend.port(), + 'config': {'connectTimeout': 10000, + 'streamingEnabled': False, + 'impressionsMode': 'debug', + 'fallbackTreatments': FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')}) + } + } + + self.factory = await get_factory_async('some_apikey', **kwargs) + await self.factory.block_until_ready(1) + assert self.factory.ready + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await self.setup_task + await _get_treatment_async(self.factory, True) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + await self.setup_task + await _get_treatment_with_config_async(self.factory) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatments(self): + await self.setup_task + await _get_treatments_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off') + ) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await self.setup_task + await _get_treatments_with_config_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off'), + ) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await self.setup_task + await _get_treatments_by_flag_set_async(self.factory) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await self.setup_task + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await self.setup_task + await _get_treatments_with_config_by_flag_set_async(self.factory) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await self.setup_task + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await self.setup_task + await _track_async(self.factory) + await self.factory.destroy() + self.split_backend.stop() + + @pytest.mark.asyncio + async def test_manager_methods(self): + """Test manager.split/splits.""" + await self.setup_task + await _manager_methods_async(self.factory, True) + await self.factory.destroy() + self.split_backend.stop() + +class RedisIntegrationAsyncTests(object): + """Redis storage-based integration tests.""" + + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = await build_async(DEFAULT_CONFIG.copy()) + await self._clear_cache(redis_client) + + split_storage = RedisSplitStorageAsync(redis_client) + segment_storage = RedisSegmentStorageAsync(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorageAsync(redis_client) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await redis_client.set(split_storage._get_key(split['name']), json.dumps(split)) + if split.get('sets') is not None: + for flag_set in split.get('sets'): + await redis_client.sadd(split_storage._get_flag_set_key(flag_set), split['name']) + await redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, data['ff']['t']) + + for rbs in data['rbs']['d']: + await redis_client.set(rb_segment_storage._get_key(rbs['name']), json.dumps(rbs)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) + await redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) + await redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + + telemetry_redis_storage = await RedisTelemetryStorageAsync.create(redis_client, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_redis_storage) + telemetry_submitter = RedisTelemetrySubmitterAsync(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': RedisImpressionsStorageAsync(redis_client, metadata), + 'events': RedisEventsStorageAsync(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorderAsync(redis_client.pipeline, impmanager, storages['events'], + storages['impressions'], telemetry_redis_storage, imp_counter=ImpressionsCounter()) + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + telemetry_submitter=telemetry_submitter, + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(self.factory).ready = ready_property + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await self.setup_task + await _get_treatment_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + await self.setup_task + await _get_treatment_with_config_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments(self): + # testing multiple splitNames + await self.setup_task + await _get_treatments_async(self.factory) + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off') + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await self.setup_task + await _get_treatments_with_config_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off'), + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await self.setup_task + await _get_treatments_by_flag_set_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await self.setup_task + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await self.setup_task + await _get_treatments_with_config_by_flag_set_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await self.setup_task + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await self.setup_task + await _track_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_manager_methods(self): + """Test manager.split/splits.""" + await self.setup_task + await _manager_methods_async(self.factory) + await self.factory.destroy() + await self._clear_cache(self.factory._storages['splits'].redis) + + async def _clear_cache(self, redis_client): + """Clear redis cache.""" + keys_to_delete = [ + "SPLITIO.split.sample_feature", + "SPLITIO.split.killed_feature", + "SPLITIO.split.regex_test", + "SPLITIO.segment.employees", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.segment.human_beigns", + "SPLITIO.impressions", + "SPLITIO.split.boolean_test", + "SPLITIO.splits.till", + "SPLITIO.split.all_feature", + "SPLITIO.segment.employees.till", + "SPLITIO.split.whitelist_feature", + "SPLITIO.telemetry.latencies", + "SPLITIO.split.dependency_test", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" + ] + for key in keys_to_delete: + await redis_client.delete(key) + +class RedisWithCacheIntegrationAsyncTests(RedisIntegrationAsyncTests): + """Run the same tests as RedisIntegratioTests but with LRU/Expirable cache overlay.""" + + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = await build_async(DEFAULT_CONFIG.copy()) + await self._clear_cache(redis_client) + + split_storage = RedisSplitStorageAsync(redis_client, True) + segment_storage = RedisSegmentStorageAsync(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorageAsync(redis_client) + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await redis_client.set(split_storage._get_key(split['name']), json.dumps(split)) + if split.get('sets') is not None: + for flag_set in split.get('sets'): + await redis_client.sadd(split_storage._get_flag_set_key(flag_set), split['name']) + await redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, data['ff']['t']) + + for rbs in data['rbs']['d']: + await redis_client.set(rb_segment_storage._get_key(rbs['name']), json.dumps(rbs)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) + await redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) + await redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + + telemetry_redis_storage = await RedisTelemetryStorageAsync.create(redis_client, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_redis_storage) + telemetry_submitter = RedisTelemetrySubmitterAsync(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': RedisImpressionsStorageAsync(redis_client, metadata), + 'events': RedisEventsStorageAsync(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorderAsync(redis_client.pipeline, impmanager, storages['events'], + storages['impressions'], telemetry_redis_storage, imp_counter=ImpressionsCounter()) + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + telemetry_submitter=telemetry_submitter, + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(self.factory).ready = ready_property + + +class LocalhostIntegrationAsyncTests(object): # pylint: disable=too-few-public-methods + """Client & Manager integration tests.""" + + @pytest.mark.asyncio + async def test_localhost_json_e2e(self): + """Instantiate a client with a JSON file and issue get_treatment() calls.""" + self._update_temp_file(splits_json['splitChange2_1']) + filename = os.path.join(os.path.dirname(__file__), 'files', 'split_changes_temp.json') + self.factory = await get_factory_async('localhost', config={'splitFile': filename}) + await self.factory.block_until_ready(1) + client = self.factory.client() + + # Tests 2 + assert await self.factory.manager().split_names() == ["SPLIT_1"] + assert await client.get_treatment("key", "SPLIT_1") == 'off' + + # Tests 1 + await self.factory._storages['splits'].update([], ['SPLIT_1'], -1) + self._update_temp_file(splits_json['splitChange1_1']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'off' + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange1_2']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'off' + assert await client.get_treatment("key", "SPLIT_2", None) == 'off' + + self._update_temp_file(splits_json['splitChange1_3']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'control' + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + # Tests 3 + await self.factory._storages['splits'].update([], ['SPLIT_1'], -1) + self._update_temp_file(splits_json['splitChange3_1']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange3_2']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_2", None) == 'off' + + # Tests 4 + await self.factory._storages['splits'].update([], ['SPLIT_2'], -1) + self._update_temp_file(splits_json['splitChange4_1']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'off' + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange4_2']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'off' + assert await client.get_treatment("key", "SPLIT_2", None) == 'off' + + self._update_temp_file(splits_json['splitChange4_3']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'control' + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + # Tests 5 + await self.factory._storages['splits'].update([], ['SPLIT_1', 'SPLIT_2'], -1) + self._update_temp_file(splits_json['splitChange5_1']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange5_2']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + # Tests 6 + await self.factory._storages['splits'].update([], ['SPLIT_2'], -1) + self._update_temp_file(splits_json['splitChange6_1']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'off' + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + self._update_temp_file(splits_json['splitChange6_2']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_1", "SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'off' + assert await client.get_treatment("key", "SPLIT_2", None) == 'off' + + self._update_temp_file(splits_json['splitChange6_3']) + await self._synchronize_now() + + assert sorted(await self.factory.manager().split_names()) == ["SPLIT_2", "SPLIT_3"] + assert await client.get_treatment("key", "SPLIT_1", None) == 'control' + assert await client.get_treatment("key", "SPLIT_2", None) == 'on' + + # rule based segment test + self._update_temp_file(splits_json['splitChange7_1']) + await self._synchronize_now() + assert await client.get_treatment('bilal@split.io', 'rbs_feature_flag', {'email': 'bilal@split.io'}) == 'on' + assert await client.get_treatment('mauro@split.io', 'rbs_feature_flag', {'email': 'mauro@split.io'}) == 'off' + + def _update_temp_file(self, json_body): + f = open(os.path.join(os.path.dirname(__file__), 'files','split_changes_temp.json'), 'w') + f.write(json.dumps(json_body)) + f.close() + + async def _synchronize_now(self): + filename = os.path.join(os.path.dirname(__file__), 'files', 'split_changes_temp.json') + self.factory._sync_manager._synchronizer._split_synchronizers._feature_flag_sync._filename = filename + await self.factory._sync_manager._synchronizer._split_synchronizers._feature_flag_sync.synchronize_splits() + + @pytest.mark.asyncio + async def test_incorrect_file_e2e(self): + """Test initialize factory with a incorrect file name.""" + # TODO: secontion below is removed when legacu use BUR + # legacy and yaml + exception_raised = False + factory = None + try: + factory = await get_factory_async('localhost', config={'splitFile': 'filename'}) + except Exception as e: + exception_raised = True + + assert(exception_raised) + + # json using BUR + factory = await get_factory_async('localhost', config={'splitFile': 'filename.json'}) + exception_raised = False + try: + await factory.block_until_ready(1) + except Exception as e: + exception_raised = True + + assert(exception_raised) + await factory.destroy() + + @pytest.mark.asyncio + async def test_localhost_e2e(self): + """Instantiate a client with a YAML file and issue get_treatment() calls.""" + filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') + factory = await get_factory_async('localhost', config={'splitFile': filename}) + await factory.block_until_ready() + client = factory.client() + assert await client.get_treatment_with_config('key', 'my_feature') == ('on', '{"desc" : "this applies only to ON treatment"}') + assert await client.get_treatment_with_config('only_key', 'my_feature') == ( + 'off', '{"desc" : "this applies only to OFF and only for only_key. The rest will receive ON"}' + ) + assert await client.get_treatment_with_config('another_key', 'my_feature') == ('control', None) + assert await client.get_treatment_with_config('key2', 'other_feature') == ('on', None) + assert await client.get_treatment_with_config('key3', 'other_feature') == ('on', None) + assert await client.get_treatment_with_config('some_key', 'other_feature_2') == ('on', None) + assert await client.get_treatment_with_config('key_whitelist', 'other_feature_3') == ('on', None) + assert await client.get_treatment_with_config('any_other_key', 'other_feature_3') == ('off', None) + + manager = factory.manager() + split = await manager.split('my_feature') + assert split.configs == { + 'on': '{"desc" : "this applies only to ON treatment"}', + 'off': '{"desc" : "this applies only to OFF and only for only_key. The rest will receive ON"}' + } + split = await manager.split('other_feature') + assert split.configs == {} + split = await manager.split('other_feature_2') + assert split.configs == {} + split = await manager.split('other_feature_3') + assert split.configs == {} + await factory.destroy() + + @pytest.mark.asyncio + async def test_fallback_treatments(self): + """Instantiate a client with a JSON file and issue get_treatment() calls.""" + self._update_temp_file(splits_json['splitChange2_1']) + filename = os.path.join(os.path.dirname(__file__), 'files', 'split_changes_temp.json') + factory = await get_factory_async('localhost', + config={ + 'splitFile': filename, + 'fallbackTreatments': FallbackTreatmentsConfiguration("on-global", {'fallback_feature': "on-local"}) + } + ) + await factory.block_until_ready(1) + client = factory.client() + + assert await client.get_treatment("key", "feature") == "on-global" + assert await client.get_treatment("key", "fallback_feature") == "on-local" + await factory.destroy() + +class PluggableIntegrationAsyncTests(object): + """Pluggable storage-based integration tests.""" + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + self.pluggable_storage_adapter = StorageMockAdapterAsync() + split_storage = PluggableSplitStorageAsync(self.pluggable_storage_adapter, 'myprefix') + segment_storage = PluggableSegmentStorageAsync(self.pluggable_storage_adapter, 'myprefix') + rb_segment_storage = PluggableRuleBasedSegmentsStorageAsync(self.pluggable_storage_adapter, 'myprefix') + + telemetry_pluggable_storage = await PluggableTelemetryStorageAsync.create(self.pluggable_storage_adapter, metadata, 'myprefix') + telemetry_producer = TelemetryStorageProducerAsync(telemetry_pluggable_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_submitter = RedisTelemetrySubmitterAsync(telemetry_pluggable_storage) + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': PluggableImpressionsStorageAsync(self.pluggable_storage_adapter, metadata), + 'events': PluggableEventsStorageAsync(self.pluggable_storage_adapter, metadata), + 'telemetry': telemetry_pluggable_storage + } + + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], + storages['impressions'], + telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_runtime_producer, imp_counter=ImpressionsCounter()) + + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + RedisManagerAsync(PluggableSynchronizerAsync()), + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + telemetry_submitter=telemetry_submitter, + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(self.factory).ready = ready_property + + # Adding data to storage + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await self.pluggable_storage_adapter.set(split_storage._prefix.format(feature_flag_name=split['name']), split) + for flag_set in split.get('sets'): + await self.pluggable_storage_adapter.push_items(split_storage._flag_set_prefix.format(flag_set=flag_set), split['name']) + await self.pluggable_storage_adapter.set(split_storage._feature_flag_till_prefix, data['ff']['d']) + + for rbs in data['rbs']['d']: + await self.pluggable_storage_adapter.set(rb_segment_storage._prefix.format(segment_name=rbs['name']), rbs) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + await self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + await self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + await self.factory.block_until_ready(1) + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await self.setup_task + await _get_treatment_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + await self.setup_task + await _get_treatment_with_config_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments(self): + # testing multiple splitNames + await self.setup_task + await _get_treatments_async(self.factory) + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off') + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await self.setup_task + await _get_treatments_with_config_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + await _validate_last_impressions_async( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('sample_feature', 'invalidKey', 'off'), + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await self.setup_task + await _get_treatments_by_flag_set_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await self.setup_task + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await self.setup_task + await _get_treatments_with_config_by_flag_set_async(self.factory) + await self.factory.destroy() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await self.setup_task + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), + ('whitelist_feature', 'user1', 'off'), + ('all_feature', 'user1', 'on') + ) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await self.setup_task + await _track_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_manager_methods(self): + """Test manager.split/splits.""" + await self.setup_task + await _manager_methods_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + async def _teardown_method(self): + """Clear pluggable cache.""" + keys_to_delete = [ + "SPLITIO.segment.human_beigns", + "SPLITIO.segment.employees.till", + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.killed_feature", + "SPLITIO.split.all_feature", + "SPLITIO.split.whitelist_feature", + "SPLITIO.segment.employees", + "SPLITIO.split.regex_test", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" + ] + + for key in keys_to_delete: + await self.pluggable_storage_adapter.delete(key) + + +class PluggableOptimizedIntegrationAsyncTests(object): + """Pluggable storage-based optimized integration tests.""" + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + self.pluggable_storage_adapter = StorageMockAdapterAsync() + split_storage = PluggableSplitStorageAsync(self.pluggable_storage_adapter) + segment_storage = PluggableSegmentStorageAsync(self.pluggable_storage_adapter) + rb_segment_storage = PluggableRuleBasedSegmentsStorageAsync(self.pluggable_storage_adapter, 'myprefix') + + telemetry_pluggable_storage = await PluggableTelemetryStorageAsync.create(self.pluggable_storage_adapter, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_pluggable_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_submitter = RedisTelemetrySubmitterAsync(telemetry_pluggable_storage) + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': PluggableImpressionsStorageAsync(self.pluggable_storage_adapter, metadata), + 'events': PluggableEventsStorageAsync(self.pluggable_storage_adapter, metadata), + 'telemetry': telemetry_pluggable_storage + } + + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], + storages['impressions'], + telemetry_producer.get_telemetry_evaluation_producer(), + telemetry_runtime_producer, + imp_counter=ImpressionsCounter()) + + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + RedisManagerAsync(PluggableSynchronizerAsync()), + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + telemetry_submitter=telemetry_submitter, + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(self.factory).ready = ready_property + + # Adding data to storage + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await self.pluggable_storage_adapter.set(split_storage._prefix.format(feature_flag_name=split['name']), split) + for flag_set in split.get('sets'): + await self.pluggable_storage_adapter.push_items(split_storage._flag_set_prefix.format(flag_set=flag_set), split['name']) + await self.pluggable_storage_adapter.set(split_storage._feature_flag_till_prefix, data['ff']['t']) + + for rbs in data['rbs']['d']: + await self.pluggable_storage_adapter.set(rb_segment_storage._prefix.format(segment_name=rbs['name']), rbs) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + await self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + await self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + await self.factory.block_until_ready(1) + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await self.setup_task + await _get_treatment_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments(self): + """Test client.get_treatments().""" + await self.setup_task + await _get_treatments_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + assert len(self.pluggable_storage_adapter._keys['SPLITIO.impressions']) == 0 + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await self.setup_task + await _get_treatments_with_config_async(self.factory) + # testing multiple splitNames + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + await _validate_last_impressions_async(client,) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await self.setup_task + await _get_treatments_by_flag_set_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await self.setup_task + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + await _validate_last_impressions_async(client, ) + assert self.pluggable_storage_adapter._keys.get('SPLITIO.impressions') == None + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await self.setup_task + await _get_treatments_with_config_by_flag_set_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await self.setup_task + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + await _validate_last_impressions_async(client, ) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_manager_methods(self): + """Test manager.split/splits.""" + await self.setup_task + await _manager_methods_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await self.setup_task + await _track_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + async def _teardown_method(self): + """Clear pluggable cache.""" + keys_to_delete = [ + "SPLITIO.segment.human_beigns", + "SPLITIO.segment.employees.till", + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.killed_feature", + "SPLITIO.split.all_feature", + "SPLITIO.split.whitelist_feature", + "SPLITIO.segment.employees", + "SPLITIO.split.regex_test", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" + ] + + for key in keys_to_delete: + await self.pluggable_storage_adapter.delete(key) + +class PluggableNoneIntegrationAsyncTests(object): + """Pluggable storage-based integration tests.""" + + def setup_method(self): + self.setup_task = asyncio.get_event_loop().create_task(self._setup_method()) + + async def _setup_method(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + self.pluggable_storage_adapter = StorageMockAdapterAsync() + split_storage = PluggableSplitStorageAsync(self.pluggable_storage_adapter) + segment_storage = PluggableSegmentStorageAsync(self.pluggable_storage_adapter) + rb_segment_storage = PluggableRuleBasedSegmentsStorageAsync(self.pluggable_storage_adapter, 'myprefix') + + telemetry_pluggable_storage = await PluggableTelemetryStorageAsync.create(self.pluggable_storage_adapter, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_pluggable_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': PluggableImpressionsStorageAsync(self.pluggable_storage_adapter, metadata), + 'events': PluggableEventsStorageAsync(self.pluggable_storage_adapter, metadata), + 'telemetry': telemetry_pluggable_storage + } + imp_counter = ImpressionsCounter() + unique_keys_tracker = UniqueKeysTrackerAsync() + unique_keys_synchronizer, clear_filter_sync, self.unique_keys_task, \ + clear_filter_task, impressions_count_sync, impressions_count_task, \ + imp_strategy, none_strategy = set_classes_async('PLUGGABLE', ImpressionsMode.NONE, self.pluggable_storage_adapter, imp_counter, unique_keys_tracker) + impmanager = ImpressionsManager(imp_strategy, none_strategy, telemetry_runtime_producer) # no listener + + recorder = StandardRecorderAsync(impmanager, storages['events'], + storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + + synchronizers = SplitSynchronizers(None, None, None, None, + impressions_count_sync, + None, + unique_keys_synchronizer, + clear_filter_sync + ) + + tasks = SplitTasks(None, None, None, None, + impressions_count_task, + None, + self.unique_keys_task, + clear_filter_task + ) + + synchronizer = RedisSynchronizerAsync(synchronizers, tasks) + + manager = RedisManagerAsync(synchronizer) + manager.start() + self.factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(FallbackTreatmentsConfiguration(None, {'fallback_feature': FallbackTreatment("on-local", '{"prop": "val"}')})) + ) # pylint:disable=attribute-defined-outside-init + + # Adding data to storage + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await self.pluggable_storage_adapter.set(split_storage._prefix.format(feature_flag_name=split['name']), split) + for flag_set in split.get('sets'): + await self.pluggable_storage_adapter.push_items(split_storage._flag_set_prefix.format(flag_set=flag_set), split['name']) + await self.pluggable_storage_adapter.set(split_storage._feature_flag_till_prefix, data['ff']['t']) + + for rbs in data['rbs']['d']: + await self.pluggable_storage_adapter.set(rb_segment_storage._prefix.format(segment_name=rbs['name']), rbs) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + await self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + await self.pluggable_storage_adapter.set(segment_storage._prefix.format(segment_name=data['name']), set(data['added'])) + await self.pluggable_storage_adapter.set(segment_storage._segment_till_prefix.format(segment_name=data['name']), data['till']) + await self.factory.block_until_ready(1) + + @pytest.mark.asyncio + async def test_get_treatment(self): + """Test client.get_treatment().""" + await self.setup_task + await _get_treatment_async(self.factory) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments(self): + """Test client.get_treatments().""" + await self.setup_task + await _get_treatments_async(self.factory) + client = self.factory.client() + result = await client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + await self.setup_task + await _get_treatments_with_config_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_set(self): + """Test client.get_treatments_by_flag_set().""" + await self.setup_task + await _get_treatments_by_flag_set_async(self.factory) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_by_flag_sets(self): + """Test client.get_treatments_by_flag_sets().""" + await self.setup_task + await _get_treatments_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': 'on', + 'whitelist_feature': 'off', + 'all_feature': 'on' + } + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_set(self): + """Test client.get_treatments_with_config_by_flag_set().""" + await self.setup_task + await _get_treatments_with_config_by_flag_set_async(self.factory) + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_get_treatments_with_config_by_flag_sets(self): + """Test client.get_treatments_with_config_by_flag_sets().""" + await self.setup_task + await _get_treatments_with_config_by_flag_sets_async(self.factory) + client = self.factory.client() + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1', 'set2', 'set4']) + assert len(result) == 3 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), + 'whitelist_feature': ('off', None), + 'all_feature': ('on', None) + } + assert self.pluggable_storage_adapter._keys['SPLITIO.impressions'] == [] + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_track(self): + """Test client.track().""" + await self.setup_task + await _track_async(self.factory) + await self.factory.destroy() + await self._teardown_method() + + @pytest.mark.asyncio + async def test_mtk(self): + await self.setup_task + client = self.factory.client() + await client.get_treatment('user1', 'sample_feature') + await client.get_treatment('invalidKey', 'sample_feature') + await client.get_treatment('invalidKey2', 'sample_feature') + await client.get_treatment('user22', 'invalidFeature') + self.unique_keys_task._task.force_execution() + await asyncio.sleep(1) + + assert(json.loads(self.pluggable_storage_adapter._keys['SPLITIO.uniquekeys'][0])["f"] =="sample_feature") + assert(json.loads(self.pluggable_storage_adapter._keys['SPLITIO.uniquekeys'][0])["ks"].sort() == + ["invalidKey2", "invalidKey", "user1"].sort()) + await self.factory.destroy() + await self._teardown_method() + + async def _teardown_method(self): + """Clear pluggable cache.""" + keys_to_delete = [ + "SPLITIO.segment.human_beigns", + "SPLITIO.segment.employees.till", + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.killed_feature", + "SPLITIO.split.all_feature", + "SPLITIO.split.whitelist_feature", + "SPLITIO.segment.employees", + "SPLITIO.split.regex_test", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.split.rbs_feature_flag", + "SPLITIO.rbsegments.till", + "SPLITIO.rbsegments.sample_rule_based_segment" + ] + + for key in keys_to_delete: + await self.pluggable_storage_adapter.delete(key) + +class InMemoryImpressionsToggleIntegrationAsyncTests(object): + """InMemory storage-based impressions toggle integration tests.""" + + @pytest.mark.asyncio + async def test_optimized(self): + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + + await split_storage.update([splits.from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': InMemoryRuleBasedSegmentStorageAsync(internal_events_queue), + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, None, UniqueKeysTrackerAsync(), ImpressionsCounter()) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(None) + ) # pylint:disable=attribute-defined-outside-init + except: + pass + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'SPLIT_1') == 'off' + assert await client.get_treatment('user1', 'SPLIT_2') == 'on' + assert await client.get_treatment('user1', 'SPLIT_3') == 'on' + imp_storage = client._factory._get_storage('impressions') + impressions = await imp_storage.pop_many(10) + assert len(impressions) == 2 + assert impressions[0].feature_name == 'SPLIT_1' + assert impressions[1].feature_name == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user1'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + await factory.destroy() + + @pytest.mark.asyncio + async def test_debug(self): + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + + await split_storage.update([splits.from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': InMemoryRuleBasedSegmentStorageAsync(internal_events_queue), + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, None, UniqueKeysTrackerAsync(), ImpressionsCounter()) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(None) + ) # pylint:disable=attribute-defined-outside-init + except: + pass + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'SPLIT_1') == 'off' + assert await client.get_treatment('user1', 'SPLIT_2') == 'on' + assert await client.get_treatment('user1', 'SPLIT_3') == 'on' + imp_storage = client._factory._get_storage('impressions') + impressions = await imp_storage.pop_many(10) + assert len(impressions) == 2 + assert impressions[0].feature_name == 'SPLIT_1' + assert impressions[1].feature_name == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user1'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + await factory.destroy() + + @pytest.mark.asyncio + async def test_none(self): + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + + await split_storage.update([splits.from_raw(splits_json['splitChange1_1']['ff']['d'][0]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][1]), + splits.from_raw(splits_json['splitChange1_1']['ff']['d'][2]) + ], [], -1) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': InMemoryRuleBasedSegmentStorageAsync(internal_events_queue), + 'impressions': InMemoryImpressionStorageAsync(5000, telemetry_runtime_producer), + 'events': InMemoryEventStorageAsync(5000, telemetry_runtime_producer), + } + impmanager = ImpressionsManager(StrategyNoneMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = StandardRecorderAsync(impmanager, storages['events'], storages['impressions'], telemetry_evaluation_producer, telemetry_runtime_producer, None, UniqueKeysTrackerAsync(), ImpressionsCounter()) + # Since we are passing None as SDK_Ready event, the factory will use the Redis telemetry call, using try catch to ignore the exception. + try: + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + None, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(None) + ) # pylint:disable=attribute-defined-outside-init + except: + pass + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'SPLIT_1') == 'off' + assert await client.get_treatment('user1', 'SPLIT_2') == 'on' + assert await client.get_treatment('user1', 'SPLIT_3') == 'on' + imp_storage = client._factory._get_storage('impressions') + impressions = await imp_storage.pop_many(10) + assert len(impressions) == 0 + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_1': {'user1'}, 'SPLIT_2': {'user1'}, 'SPLIT_3': {'user1'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 3 + assert imps_count[0].feature == 'SPLIT_1' + assert imps_count[0].count == 1 + assert imps_count[1].feature == 'SPLIT_2' + assert imps_count[1].count == 1 + assert imps_count[2].feature == 'SPLIT_3' + assert imps_count[2].count == 1 + await factory.destroy() + +class RedisImpressionsToggleIntegrationAsyncTests(object): + """Run impression toggle tests for Redis.""" + + @pytest.mark.asyncio + async def test_optimized(self): + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + """Prepare storages with test data.""" + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = await build_async(DEFAULT_CONFIG.copy()) + split_storage = RedisSplitStorageAsync(redis_client, True) + segment_storage = RedisSegmentStorageAsync(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorageAsync(redis_client) + + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][0]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][0])) + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][1]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][1])) + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][2]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][2])) + await redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, -1) + + telemetry_redis_storage = await RedisTelemetryStorageAsync.create(redis_client, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': RedisImpressionsStorageAsync(redis_client, metadata), + 'events': RedisEventsStorageAsync(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyOptimizedMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorderAsync(redis_client.pipeline, impmanager, + storages['events'], storages['impressions'], telemetry_redis_storage, unique_keys_tracker=UniqueKeysTracker(), imp_counter=ImpressionsCounter()) + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(None) + ) # pylint:disable=attribute-defined-outside-init + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'SPLIT_1') == 'off' + assert await client.get_treatment('user2', 'SPLIT_2') == 'on' + assert await client.get_treatment('user3', 'SPLIT_3') == 'on' + await asyncio.sleep(0.2) + + imp_storage = factory._storages['impressions'] + impressions = [] + while True: + impression = await redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY) + if impression is None: + break + impressions.append(json.loads(impression)) + + assert len(impressions) == 2 + assert impressions[0]['i']['f'] == 'SPLIT_1' + assert impressions[1]['i']['f'] == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user3'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + await self.clear_cache() + await factory.destroy() + + @pytest.mark.asyncio + async def test_debug(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = await build_async(DEFAULT_CONFIG.copy()) + split_storage = RedisSplitStorageAsync(redis_client, True) + segment_storage = RedisSegmentStorageAsync(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorageAsync(redis_client) + + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][0]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][0])) + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][1]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][1])) + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][2]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][2])) + await redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, -1) + + telemetry_redis_storage = await RedisTelemetryStorageAsync.create(redis_client, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': RedisImpressionsStorageAsync(redis_client, metadata), + 'events': RedisEventsStorageAsync(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyDebugMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorderAsync(redis_client.pipeline, impmanager, + storages['events'], storages['impressions'], telemetry_redis_storage, unique_keys_tracker=UniqueKeysTracker(), imp_counter=ImpressionsCounter()) + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(None) + ) # pylint:disable=attribute-defined-outside-init + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'SPLIT_1') == 'off' + assert await client.get_treatment('user2', 'SPLIT_2') == 'on' + assert await client.get_treatment('user3', 'SPLIT_3') == 'on' + await asyncio.sleep(0.2) + + imp_storage = factory._storages['impressions'] + impressions = [] + while True: + impression = await redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY) + if impression is None: + break + impressions.append(json.loads(impression)) + + assert len(impressions) == 2 + assert impressions[0]['i']['f'] == 'SPLIT_1' + assert impressions[1]['i']['f'] == 'SPLIT_2' + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_3': {'user3'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 1 + assert imps_count[0].feature == 'SPLIT_3' + assert imps_count[0].count == 1 + await self.clear_cache() + await factory.destroy() + + @pytest.mark.asyncio + async def test_none(self): + """Prepare storages with test data.""" + internal_events_queue = asyncio.Queue() + events_manager = EventsManagerAsync(EventsManagerConfig(), EventsDelivery()) + + metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') + redis_client = await build_async(DEFAULT_CONFIG.copy()) + split_storage = RedisSplitStorageAsync(redis_client, True) + segment_storage = RedisSegmentStorageAsync(redis_client) + rb_segment_storage = RedisRuleBasedSegmentsStorageAsync(redis_client) + + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][0]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][0])) + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][1]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][1])) + await redis_client.set(split_storage._get_key(splits_json['splitChange1_1']['ff']['d'][2]['name']), json.dumps(splits_json['splitChange1_1']['ff']['d'][2])) + await redis_client.set(split_storage._FEATURE_FLAG_TILL_KEY, -1) + + telemetry_redis_storage = await RedisTelemetryStorageAsync.create(redis_client, metadata) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_redis_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + storages = { + 'splits': split_storage, + 'segments': segment_storage, + 'rule_based_segments': rb_segment_storage, + 'impressions': RedisImpressionsStorageAsync(redis_client, metadata), + 'events': RedisEventsStorageAsync(redis_client, metadata), + } + impmanager = ImpressionsManager(StrategyNoneMode(), StrategyNoneMode(), telemetry_runtime_producer) # no listener + recorder = PipelinedRecorderAsync(redis_client.pipeline, impmanager, + storages['events'], storages['impressions'], telemetry_redis_storage, unique_keys_tracker=UniqueKeysTracker(), imp_counter=ImpressionsCounter()) + factory = SplitFactoryAsync('some_api_key', + storages, + True, + recorder, + internal_events_queue, + events_manager, + telemetry_producer=telemetry_producer, + telemetry_init_producer=telemetry_producer.get_telemetry_init_producer(), + fallback_treatment_calculator=FallbackTreatmentCalculator(None) + ) # pylint:disable=attribute-defined-outside-init + ready_property = mocker.PropertyMock() + ready_property.return_value = True + type(factory).ready = ready_property + + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'SPLIT_1') == 'off' + assert await client.get_treatment('user2', 'SPLIT_2') == 'on' + assert await client.get_treatment('user3', 'SPLIT_3') == 'on' + await asyncio.sleep(0.2) + + imp_storage = factory._storages['impressions'] + impressions = [] + while True: + impression = await redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY) + if impression is None: + break + impressions.append(json.loads(impression)) + + assert len(impressions) == 0 + assert client._recorder._unique_keys_tracker._cache == {'SPLIT_1': {'user1'}, 'SPLIT_2': {'user2'}, 'SPLIT_3': {'user3'}} + imps_count = client._recorder._imp_counter.pop_all() + assert len(imps_count) == 3 + assert imps_count[0].feature == 'SPLIT_1' + assert imps_count[0].count == 1 + assert imps_count[1].feature == 'SPLIT_2' + assert imps_count[1].count == 1 + assert imps_count[2].feature == 'SPLIT_3' + assert imps_count[2].count == 1 + await self.clear_cache() + await factory.destroy() + + async def clear_cache(self): + """Clear redis cache.""" + keys_to_delete = [ + "SPLITIO.split.SPLIT_3", + "SPLITIO.splits.till", + "SPLITIO.split.SPLIT_2", + "SPLITIO.split.SPLIT_1", + "SPLITIO.telemetry.latencies" + ] + + redis_client = await build_async(DEFAULT_CONFIG.copy()) + for key in keys_to_delete: + await redis_client.delete(key) + +async def _validate_last_impressions_async(client, *to_validate): + """Validate the last N impressions are present disregarding the order.""" + imp_storage = client._factory._get_storage('impressions') + if isinstance(client._factory._get_storage('splits'), RedisSplitStorageAsync) or isinstance(client._factory._get_storage('splits'), PluggableSplitStorageAsync): + if isinstance(client._factory._get_storage('splits'), RedisSplitStorageAsync): + redis_client = imp_storage._redis + impressions_raw = [ + json.loads(await redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY)) + for _ in to_validate + ] + else: + pluggable_adapter = imp_storage._pluggable_adapter + results = await pluggable_adapter.pop_items(imp_storage._impressions_queue_key) + results = [] if results == None else results + impressions_raw = [ + json.loads(i) + for i in results + ] + as_tup_set = set( + (i['i']['f'], i['i']['k'], i['i']['t']) + for i in impressions_raw + ) + assert as_tup_set == set(to_validate) + await asyncio.sleep(0.2) # delay for redis to sync + else: + impressions = await imp_storage.pop_many(len(to_validate)) + as_tup_set = set((i.feature_name, i.matching_key, i.treatment) for i in impressions) + assert as_tup_set == set(to_validate) + +async def _validate_last_events_async(client, *to_validate): + """Validate the last N impressions are present disregarding the order.""" + event_storage = client._factory._get_storage('events') + if isinstance(client._factory._get_storage('splits'), RedisSplitStorageAsync) or isinstance(client._factory._get_storage('splits'), PluggableSplitStorageAsync): + if isinstance(client._factory._get_storage('splits'), RedisSplitStorageAsync): + redis_client = event_storage._redis + events_raw = [ + json.loads(await redis_client.lpop(event_storage._EVENTS_KEY_TEMPLATE)) + for _ in to_validate + ] + else: + pluggable_adapter = event_storage._pluggable_adapter + events_raw = [ + json.loads(i) + for i in await pluggable_adapter.pop_items(event_storage._events_queue_key) + ] + as_tup_set = set( + (i['e']['key'], i['e']['trafficTypeName'], i['e']['eventTypeId'], i['e']['value'], str(i['e']['properties'])) + for i in events_raw + ) + assert as_tup_set == set(to_validate) + else: + events = await event_storage.pop_many(len(to_validate)) + as_tup_set = set((i.key, i.traffic_type_name, i.event_type_id, i.value, str(i.properties)) for i in events) + assert as_tup_set == set(to_validate) + +async def _get_treatment_async(factory, skip_rbs=False): + """Test client.get_treatment().""" + try: + client = factory.client() + except: + pass + + assert await client.get_treatment('user1', 'sample_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on')) + + assert await client.get_treatment('invalidKey', 'sample_feature') == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'invalidKey', 'off')) + + assert await client.get_treatment('invalidKey', 'invalid_feature') == 'control' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) # No impressions should be present + + # testing a killed feature. No matter what the key, must return default treatment + assert await client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + assert await client.get_treatment('invalidKey', 'all_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'invalidKey', 'on')) + + # testing WHITELIST matcher + assert await client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('whitelist_feature', 'whitelisted_user', 'on')) + assert await client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) + + # testing INVALID matcher + assert await client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) # No impressions should be present + + # testing Dependency matcher + assert await client.get_treatment('somekey', 'dependency_test') == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('dependency_test', 'somekey', 'off')) + + # testing boolean matcher + assert await client.get_treatment('True', 'boolean_test') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('boolean_test', 'True', 'on')) + + # testing regex matcher + assert await client.get_treatment('abc4', 'regex_test') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('regex_test', 'abc4', 'on')) + + # test fallback treatment + assert await client.get_treatment('user4321', 'fallback_feature') == 'on-local' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) # No impressions should be present + + if skip_rbs: + return + + # test prerequisites matcher + assert await client.get_treatment('abc4', 'prereq_feature') == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('prereq_feature', 'abc4', 'on')) + + # test prerequisites matcher + assert await client.get_treatment('user1234', 'prereq_feature') == 'off_default' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('prereq_feature', 'user1234', 'off_default')) + + # test rule based segment matcher + assert await client.get_treatment('bilal@split.io', 'rbs_feature_flag', {'email': 'bilal@split.io'}) == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('rbs_feature_flag', 'bilal@split.io', 'on')) + + # test rule based segment matcher + assert await client.get_treatment('mauro@split.io', 'rbs_feature_flag', {'email': 'mauro@split.io'}) == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('rbs_feature_flag', 'mauro@split.io', 'off')) + +async def _get_treatment_with_config_async(factory): + """Test client.get_treatment_with_config().""" + try: + client = factory.client() + except: + pass + result = await client.get_treatment_with_config('user1', 'sample_feature') + assert result == ('on', '{"size":15,"test":20}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on')) + + result = await client.get_treatment_with_config('invalidKey', 'sample_feature') + assert result == ('off', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'invalidKey', 'off')) + + result = await client.get_treatment_with_config('invalidKey', 'invalid_feature') + assert result == ('control', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatment_with_config('invalidKey', 'killed_feature') + assert ('defTreatment', '{"size":15,"defTreatment":true}') == result + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatment_with_config('invalidKey', 'all_feature') + assert result == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'invalidKey', 'on')) + + # test fallback treatment + assert await client.get_treatment_with_config('user4321', 'fallback_feature') == ('on-local', '{"prop": "val"}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) # No impressions should be present + +async def _get_treatments_async(factory): + """Test client.get_treatments().""" + try: + client = factory.client() + except: + pass + result = await client.get_treatments('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on')) + + result = await client.get_treatments('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'off' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'invalidKey', 'off')) + + result = await client.get_treatments('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == 'control' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatments('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatments('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'invalidKey', 'on')) + + # test fallback treatment + assert await client.get_treatments('user4321', ['fallback_feature']) == {'fallback_feature': 'on-local'} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) # No impressions should be present + +async def _get_treatments_with_config_async(factory): + """Test client.get_treatments_with_config().""" + try: + client = factory.client() + except: + pass + + result = await client.get_treatments_with_config('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('on', '{"size":15,"test":20}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on')) + + result = await client.get_treatments_with_config('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('off', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'invalidKey', 'off')) + + result = await client.get_treatments_with_config('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == ('control', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatments_with_config('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatments_with_config('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'invalidKey', 'on')) + + # test fallback treatment + assert await client.get_treatments_with_config('user4321', ['fallback_feature']) == {'fallback_feature': ('on-local', '{"prop": "val"}')} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client) # No impressions should be present + +async def _get_treatments_by_flag_set_async(factory): + """Test client.get_treatments_by_flag_set().""" + try: + client = factory.client() + except: + pass + result = await client.get_treatments_by_flag_set('user1', 'set1') + assert len(result) == 2 + assert result == {'sample_feature': 'on', 'whitelist_feature': 'off'} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = await client.get_treatments_by_flag_set('invalidKey', 'invalid_set') + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatments_by_flag_set('invalidKey', 'set3') + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatments_by_flag_set('invalidKey', 'set4') + assert len(result) == 1 + assert result['all_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'invalidKey', 'on')) + +async def _get_treatments_by_flag_sets_async(factory): + """Test client.get_treatments_by_flag_sets().""" + try: + client = factory.client() + except: + pass + result = await client.get_treatments_by_flag_sets('user1', ['set1']) + assert len(result) == 2 + assert result == {'sample_feature': 'on', 'whitelist_feature': 'off'} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = await client.get_treatments_by_flag_sets('invalidKey', ['invalid_set']) + assert len(result) == 0 + assert result == {} + + result = await client.get_treatments_by_flag_sets('invalidKey', []) + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatments_by_flag_sets('invalidKey', ['set3']) + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatments_by_flag_sets('user1', ['set4']) + assert len(result) == 1 + assert result['all_feature'] == 'on' + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'user1', 'on')) + +async def _get_treatments_with_config_by_flag_set_async(factory): + """Test client.get_treatments_with_config_by_flag_set().""" + try: + client = factory.client() + except: + pass + result = await client.get_treatments_with_config_by_flag_set('user1', 'set1') + assert len(result) == 2 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), 'whitelist_feature': ('off', None)} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = await client.get_treatments_with_config_by_flag_set('invalidKey', 'invalid_set') + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatments_with_config_by_flag_set('invalidKey', 'set3') + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatments_with_config_by_flag_set('invalidKey', 'set4') + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'invalidKey', 'on')) + +async def _get_treatments_with_config_by_flag_sets_async(factory): + """Test client.get_treatments_with_config_by_flag_sets().""" + try: + client = factory.client() + except: + pass + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set1']) + assert len(result) == 2 + assert result == {'sample_feature': ('on', '{"size":15,"test":20}'), 'whitelist_feature': ('off', None)} + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('sample_feature', 'user1', 'on'), ('whitelist_feature', 'user1', 'off')) + + result = await client.get_treatments_with_config_by_flag_sets('invalidKey', ['invalid_set']) + assert len(result) == 0 + assert result == {} + + result = await client.get_treatments_with_config_by_flag_sets('invalidKey', []) + assert len(result) == 0 + assert result == {} + + # testing a killed feature. No matter what the key, must return default treatment + result = await client.get_treatments_with_config_by_flag_sets('invalidKey', ['set3']) + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = await client.get_treatments_with_config_by_flag_sets('user1', ['set4']) + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + if not isinstance(factory._recorder._impressions_manager._strategy, StrategyNoneMode): + await _validate_last_impressions_async(client, ('all_feature', 'user1', 'on')) + +async def _track_async(factory): + """Test client.track().""" + try: + client = factory.client() + except: + pass + assert(await client.track('user1', 'user', 'conversion', 1, {"prop1": "value1"})) + assert(not await client.track(None, 'user', 'conversion')) + assert(not await client.track('user1', None, 'conversion')) + assert(not await client.track('user1', 'user', None)) + await _validate_last_events_async( + client, + ('user1', 'user', 'conversion', 1, "{'prop1': 'value1'}") + ) + +async def _manager_methods_async(factory, skip_rbs=False): + """Test manager.split/splits.""" + try: + manager = factory.manager() + except: + pass + result = await manager.split('all_feature') + assert result.name == 'all_feature' + assert result.traffic_type is None + assert result.killed is False + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs == {} + + result = await manager.split('killed_feature') + assert result.name == 'killed_feature' + assert result.traffic_type is None + assert result.killed is True + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}' + assert result.configs['off'] == '{"size":15,"test":20}' + + result = await manager.split('sample_feature') + assert result.name == 'sample_feature' + assert result.traffic_type is None + assert result.killed is False + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs['on'] == '{"size":15,"test":20}' + + if skip_rbs: + assert len(await manager.split_names()) == 7 + assert len(await manager.splits()) == 7 + return + + assert len(await manager.split_names()) == 9 + assert len(await manager.splits()) == 9 \ No newline at end of file diff --git a/tests/integration/test_pluggable_integration.py b/tests/integration/test_pluggable_integration.py new file mode 100644 index 00000000..59534193 --- /dev/null +++ b/tests/integration/test_pluggable_integration.py @@ -0,0 +1,443 @@ +"""Pluggable storage end to end tests.""" +#pylint: disable=no-self-use,protected-access,line-too-long,too-few-public-methods +import pytest +import json +import os + +from splitio.client.util import get_metadata +from splitio.models import splits, impressions, events +from splitio.storage.pluggable import PluggableEventsStorage, PluggableImpressionsStorage, PluggableSegmentStorage, \ + PluggableSplitStorage, PluggableEventsStorageAsync, PluggableImpressionsStorageAsync, PluggableSegmentStorageAsync,\ + PluggableSplitStorageAsync +from splitio.client.config import DEFAULT_CONFIG +from tests.storage.test_pluggable import StorageMockAdapter, StorageMockAdapterAsync + +class PluggableSplitStorageIntegrationTests(object): + """Pluggable Split storage e2e tests.""" + + def test_put_fetch(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapter() + try: + storage = PluggableSplitStorage(adapter) + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + adapter.set(storage._prefix.format(feature_flag_name=split['name']), split) + adapter.increment(storage._traffic_type_prefix.format(traffic_type_name=split['trafficTypeName']), 1) + adapter.set(storage._feature_flag_till_prefix, data['ff']['t']) + + split_objects = [splits.from_raw(raw) for raw in data['ff']['d']] + for split_object in split_objects: + raw = split_object.to_json() + + original_splits = {split.name: split for split in split_objects} + fetched_splits = {name: storage.get(name) for name in original_splits.keys()} + + assert set(original_splits.keys()) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + + adapter.set(storage._feature_flag_till_prefix, data['ff']['t']) + assert storage.get_change_number() == data['ff']['t'] + + assert storage.is_valid_traffic_type('user') is True + assert storage.is_valid_traffic_type('account') is True + assert storage.is_valid_traffic_type('anything-else') is False + + finally: + to_delete = [ + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.all_feature", + "SPLITIO.split.killed_feature", + "SPLITIO.split.Risk_Max_Deductible", + "SPLITIO.split.whitelist_feature", + "SPLITIO.split.regex_test", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.trafficType.user", + "SPLITIO.trafficType.account" + ] + for item in to_delete: + adapter.delete(item) + + storage = PluggableSplitStorage(adapter) + assert storage.is_valid_traffic_type('user') is False + assert storage.is_valid_traffic_type('account') is False + + def test_get_all(self): + """Test get all names & splits.""" + adapter = StorageMockAdapter() + try: + storage = PluggableSplitStorage(adapter) + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + adapter.set(storage._prefix.format(feature_flag_name=split['name']), split) + adapter.increment(storage._traffic_type_prefix.format(traffic_type_name=split['trafficTypeName']), 1) + adapter.set(storage._feature_flag_till_prefix, data['ff']['t']) + + split_objects = [splits.from_raw(raw) for raw in data['ff']['d']] + original_splits = {split.name: split for split in split_objects} + fetched_names = storage.get_split_names() + fetched_splits = {split.name: split for split in storage.get_all_splits()} + assert set(fetched_names) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + finally: + [adapter.delete(key) for key in ['SPLITIO.split.sample_feature', + 'SPLITIO.splits.till', + 'SPLITIO.split.all_feature', + 'SPLITIO.split.killed_feature', + 'SPLITIO.split.Risk_Max_Deductible', + 'SPLITIO.split.whitelist_feature', + 'SPLITIO.split.regex_test', + 'SPLITIO.split.boolean_test', + 'SPLITIO.split.dependency_test']] + + +class PluggableSegmentStorageIntegrationTests(object): + """Pluggable Segment storage e2e tests.""" + + def test_put_fetch_contains(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapter() + try: + storage = PluggableSegmentStorage(adapter) + adapter.set(storage._prefix.format(segment_name='some_segment'), {'key1', 'key2', 'key3', 'key4'}) + adapter.set(storage._segment_till_prefix.format(segment_name='some_segment'), 123) + assert storage.segment_contains('some_segment', 'key0') is False + assert storage.segment_contains('some_segment', 'key1') is True + assert storage.segment_contains('some_segment', 'key2') is True + assert storage.segment_contains('some_segment', 'key3') is True + assert storage.segment_contains('some_segment', 'key4') is True + assert storage.segment_contains('some_segment', 'key5') is False + + fetched = storage.get('some_segment') + assert fetched.keys == set(['key1', 'key2', 'key3', 'key4']) + assert fetched.change_number == 123 + finally: + adapter.delete('SPLITIO.segment.some_segment') + adapter.delete('SPLITIO.segment.some_segment.till') + + +class PluggableImpressionsStorageIntegrationTests(object): + """Pluggable Impressions storage e2e tests.""" + + def _put_impressions(self, adapter, metadata): + storage = PluggableImpressionsStorage(adapter, metadata) + storage.put([ + impressions.Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + impressions.Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + impressions.Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ]) + + + def test_put_fetch_contains(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapter() + try: + self._put_impressions(adapter, get_metadata({})) + + imps = adapter.pop_items('SPLITIO.impressions') + assert len(imps) == 3 + for rawImpression in imps: + impression = json.loads(rawImpression) + assert impression['m']['i'] != 'NA' + assert impression['m']['n'] != 'NA' + finally: + adapter.delete('SPLITIO.impressions') + + def test_put_fetch_contains_ip_address_disabled(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapter() + try: + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + self._put_impressions(adapter, get_metadata(cfg)) + + imps = adapter.pop_items('SPLITIO.impressions') + assert len(imps) == 3 + for rawImpression in imps: + impression = json.loads(rawImpression) + assert impression['m']['i'] == 'NA' + assert impression['m']['n'] == 'NA' + finally: + adapter.delete('SPLITIO.impressions') + + +class PluggableEventsStorageIntegrationTests(object): + """Pluggable Events storage e2e tests.""" + def _put_events(self, adapter, metadata): + storage = PluggableEventsStorage(adapter, metadata) + storage.put([ + events.EventWrapper( + event=events.Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + events.EventWrapper( + event=events.Event('key2', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + events.EventWrapper( + event=events.Event('key3', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + ]) + + def test_put_fetch_contains(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapter() + try: + self._put_events(adapter, get_metadata({})) + evts = adapter.pop_items('SPLITIO.events') + assert len(evts) == 3 + for rawEvent in evts: + event = json.loads(rawEvent) + assert event['m']['i'] != 'NA' + assert event['m']['n'] != 'NA' + finally: + adapter.delete('SPLITIO.events') + + def test_put_fetch_contains_ip_address_disabled(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapter() + try: + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + self._put_events(adapter, get_metadata(cfg)) + + evts = adapter.pop_items('SPLITIO.events') + assert len(evts) == 3 + for rawEvent in evts: + event = json.loads(rawEvent) + assert event['m']['i'] == 'NA' + assert event['m']['n'] == 'NA' + finally: + adapter.delete('SPLITIO.events') + + +class PluggableSplitStorageIntegrationAsyncTests(object): + """Pluggable Split storage e2e tests.""" + + @pytest.mark.asyncio + async def test_put_fetch(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapterAsync() + try: + storage = PluggableSplitStorageAsync(adapter) + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await adapter.set(storage._prefix.format(feature_flag_name=split['name']), split) + await adapter.increment(storage._traffic_type_prefix.format(traffic_type_name=split['trafficTypeName']), 1) + await adapter.set(storage._feature_flag_till_prefix, data['ff']['t']) + + split_objects = [splits.from_raw(raw) for raw in data['ff']['d']] + for split_object in split_objects: + raw = split_object.to_json() + + original_splits = {split.name: split for split in split_objects} + fetched_splits = {name: await storage.get(name) for name in original_splits.keys()} + + assert set(original_splits.keys()) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + + await adapter.set(storage._feature_flag_till_prefix, data['ff']['t']) + assert await storage.get_change_number() == data['ff']['t'] + + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is True + assert await storage.is_valid_traffic_type('anything-else') is False + + finally: + to_delete = [ + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.all_feature", + "SPLITIO.split.killed_feature", + "SPLITIO.split.Risk_Max_Deductible", + "SPLITIO.split.whitelist_feature", + "SPLITIO.split.regex_test", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.trafficType.user", + "SPLITIO.trafficType.account" + ] + for item in to_delete: + await adapter.delete(item) + + storage = PluggableSplitStorageAsync(adapter) + assert await storage.is_valid_traffic_type('user') is False + assert await storage.is_valid_traffic_type('account') is False + + @pytest.mark.asyncio + async def test_get_all(self): + """Test get all names & splits.""" + adapter = StorageMockAdapterAsync() + try: + storage = PluggableSplitStorageAsync(adapter) + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['ff']['d']: + await adapter.set(storage._prefix.format(feature_flag_name=split['name']), split) + await adapter.increment(storage._traffic_type_prefix.format(traffic_type_name=split['trafficTypeName']), 1) + await adapter.set(storage._feature_flag_till_prefix, data['ff']['t']) + + split_objects = [splits.from_raw(raw) for raw in data['ff']['d']] + original_splits = {split.name: split for split in split_objects} + fetched_names = await storage.get_split_names() + fetched_splits = {split.name: split for split in await storage.get_all_splits()} + assert set(fetched_names) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + finally: + [await adapter.delete(key) for key in ['SPLITIO.split.sample_feature', + 'SPLITIO.splits.till', + 'SPLITIO.split.all_feature', + 'SPLITIO.split.killed_feature', + 'SPLITIO.split.Risk_Max_Deductible', + 'SPLITIO.split.whitelist_feature', + 'SPLITIO.split.regex_test', + 'SPLITIO.split.boolean_test', + 'SPLITIO.split.dependency_test']] + + +class PluggableSegmentStorageIntegrationAsyncTests(object): + """Pluggable Segment storage e2e tests.""" + + @pytest.mark.asyncio + async def test_put_fetch_contains(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapterAsync() + try: + storage = PluggableSegmentStorageAsync(adapter) + await adapter.set(storage._prefix.format(segment_name='some_segment'), {'key1', 'key2', 'key3', 'key4'}) + await adapter.set(storage._segment_till_prefix.format(segment_name='some_segment'), 123) + assert await storage.segment_contains('some_segment', 'key0') is False + assert await storage.segment_contains('some_segment', 'key1') is True + assert await storage.segment_contains('some_segment', 'key2') is True + assert await storage.segment_contains('some_segment', 'key3') is True + assert await storage.segment_contains('some_segment', 'key4') is True + assert await storage.segment_contains('some_segment', 'key5') is False + + fetched = await storage.get('some_segment') + assert fetched.keys == set(['key1', 'key2', 'key3', 'key4']) + assert fetched.change_number == 123 + finally: + await adapter.delete('SPLITIO.segment.some_segment') + await adapter.delete('SPLITIO.segment.some_segment.till') + +class PluggableEventsStorageIntegrationAsyncTests(object): + """Pluggable Events storage e2e tests.""" + async def _put_events(self, adapter, metadata): + storage = PluggableEventsStorageAsync(adapter, metadata) + await storage.put([ + events.EventWrapper( + event=events.Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + events.EventWrapper( + event=events.Event('key2', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + events.EventWrapper( + event=events.Event('key3', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + ]) + + @pytest.mark.asyncio + async def test_put_fetch_contains(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapterAsync() + try: + await self._put_events(adapter, get_metadata({})) + evts = await adapter.pop_items('SPLITIO.events') + assert len(evts) == 3 + for rawEvent in evts: + event = json.loads(rawEvent) + assert event['m']['i'] != 'NA' + assert event['m']['n'] != 'NA' + finally: + await adapter.delete('SPLITIO.events') + + @pytest.mark.asyncio + async def test_put_fetch_contains_ip_address_disabled(self): + """Test storing and retrieving splits in pluggable.""" + adapter = StorageMockAdapterAsync() + try: + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + await self._put_events(adapter, get_metadata(cfg)) + + evts = await adapter.pop_items('SPLITIO.events') + assert len(evts) == 3 + for rawEvent in evts: + event = json.loads(rawEvent) + assert event['m']['i'] == 'NA' + assert event['m']['n'] == 'NA' + finally: + await adapter.delete('SPLITIO.events') diff --git a/tests/integration/test_redis_integration.py b/tests/integration/test_redis_integration.py index 685f72c5..4c85beda 100644 --- a/tests/integration/test_redis_integration.py +++ b/tests/integration/test_redis_integration.py @@ -1,32 +1,37 @@ """Redis storage end to end tests.""" #pylint: disable=no-self-use,protected-access,line-too-long,too-few-public-methods - +import pytest import json import os from splitio.client.util import get_metadata from splitio.models import splits, impressions, events from splitio.storage.redis import RedisSplitStorage, RedisSegmentStorage, RedisImpressionsStorage, \ - RedisEventsStorage -from splitio.storage.adapters.redis import _build_default_client + RedisEventsStorage, RedisEventsStorageAsync, RedisImpressionsStorageAsync, RedisSegmentStorageAsync, \ + RedisSplitStorageAsync +from splitio.storage.adapters.redis import _build_default_client, _build_default_client_async, StrictRedis from splitio.client.config import DEFAULT_CONFIG -class SplitStorageTests(object): +class RedisSplitStorageTests(object): """Redis Split storage e2e tests.""" def test_put_fetch(self): """Test storing and retrieving splits in redis.""" - adapter = _build_default_client({}) + redis = StrictRedis(host="localhost") + redis.acl_setuser(username='redis_user', enabled=True, passwords=["+split"], categories=["+admin"], + commands=["+@all"], keys=["~*"]) + redis.close() + adapter = _build_default_client({'redisUsername': 'redis_user', 'redisPassword': 'split'}) try: storage = RedisSplitStorage(adapter) with open(os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json'), 'r') as flo: split_changes = json.load(flo) - split_objects = [splits.from_raw(raw) for raw in split_changes['splits']] + split_objects = [splits.from_raw(raw) for raw in split_changes['ff']['d']] for split_object in split_objects: raw = split_object.to_json() - adapter.set(RedisSplitStorage._SPLIT_KEY.format(split_name=split_object.name), json.dumps(raw)) + adapter.set(RedisSplitStorage._FEATURE_FLAG_KEY.format(feature_flag_name=split_object.name), json.dumps(raw)) adapter.incr(RedisSplitStorage._TRAFFIC_TYPE_KEY.format(traffic_type_name=split_object.traffic_type_name)) original_splits = {split.name: split for split in split_objects} @@ -50,8 +55,8 @@ def test_put_fetch(self): assert len(original_condition.matchers) == len(fetched_condition.matchers) assert len(original_condition.partitions) == len(fetched_condition.partitions) - adapter.set(RedisSplitStorage._SPLIT_TILL_KEY, split_changes['till']) - assert storage.get_change_number() == split_changes['till'] + adapter.set(RedisSplitStorage._FEATURE_FLAG_TILL_KEY, split_changes['ff']['t']) + assert storage.get_change_number() == split_changes['ff']['t'] assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is True @@ -73,10 +78,12 @@ def test_put_fetch(self): ] for item in to_delete: adapter.delete(item) - storage = RedisSplitStorage(adapter) assert storage.is_valid_traffic_type('user') is False assert storage.is_valid_traffic_type('account') is False + redis = StrictRedis(host="localhost") + redis.acl_deluser("redis_user") + redis.close() def test_get_all(self): """Test get all names & splits.""" @@ -86,10 +93,10 @@ def test_get_all(self): with open(os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json'), 'r') as flo: split_changes = json.load(flo) - split_objects = [splits.from_raw(raw) for raw in split_changes['splits']] + split_objects = [splits.from_raw(raw) for raw in split_changes['ff']['d']] for split_object in split_objects: raw = split_object.to_json() - adapter.set(RedisSplitStorage._SPLIT_KEY.format(split_name=split_object.name), json.dumps(raw)) + adapter.set(RedisSplitStorage._FEATURE_FLAG_KEY.format(feature_flag_name=split_object.name), json.dumps(raw)) original_splits = {split.name: split for split in split_objects} fetched_names = storage.get_split_names() @@ -124,7 +131,7 @@ def test_get_all(self): 'SPLITIO.split.dependency_test' ) -class SegmentStorageTests(object): +class RedisSegmentStorageTests(object): """Redis Segment storage e2e tests.""" def test_put_fetch_contains(self): @@ -134,12 +141,12 @@ def test_put_fetch_contains(self): storage = RedisSegmentStorage(adapter) adapter.sadd(storage._get_key('some_segment'), 'key1', 'key2', 'key3', 'key4') adapter.set(storage._get_till_key('some_segment'), 123) - assert storage.segment_contains('some_segment', 'key0') is False - assert storage.segment_contains('some_segment', 'key1') is True - assert storage.segment_contains('some_segment', 'key2') is True - assert storage.segment_contains('some_segment', 'key3') is True - assert storage.segment_contains('some_segment', 'key4') is True - assert storage.segment_contains('some_segment', 'key5') is False + assert storage.segment_contains('some_segment', 'key0') == 0 + assert storage.segment_contains('some_segment', 'key1') == 1 + assert storage.segment_contains('some_segment', 'key2') == 1 + assert storage.segment_contains('some_segment', 'key3') == 1 + assert storage.segment_contains('some_segment', 'key4') == 1 + assert storage.segment_contains('some_segment', 'key5') == 0 fetched = storage.get('some_segment') assert fetched.keys == set(['key1', 'key2', 'key3', 'key4']) @@ -148,15 +155,15 @@ def test_put_fetch_contains(self): adapter.delete('SPLITIO.segment.some_segment', 'SPLITIO.segment.some_segment.till') -class ImpressionsStorageTests(object): +class RedisImpressionsStorageTests(object): """Redis Impressions storage e2e tests.""" def _put_impressions(self, adapter, metadata): storage = RedisImpressionsStorage(adapter, metadata) storage.put([ - impressions.Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654), - impressions.Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654), - impressions.Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + impressions.Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + impressions.Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + impressions.Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ]) @@ -193,7 +200,7 @@ def test_put_fetch_contains_ip_address_disabled(self): adapter.delete('SPLITIO.impressions') -class EventsStorageTests(object): +class RedisEventsStorageTests(object): """Redis Events storage e2e tests.""" def _put_events(self, adapter, metadata): storage = RedisEventsStorage(adapter, metadata) @@ -242,3 +249,240 @@ def test_put_fetch_contains_ip_address_disabled(self): assert event['m']['n'] == 'NA' finally: adapter.delete('SPLITIO.events') + +class RedisSplitStorageAsyncTests(object): + """Redis Split storage e2e tests.""" + + @pytest.mark.asyncio + async def test_put_fetch(self): + """Test storing and retrieving splits in redis.""" + adapter = await _build_default_client_async({}) + try: + storage = RedisSplitStorageAsync(adapter) + with open(os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json'), 'r') as flo: + split_changes = json.load(flo) + + split_objects = [splits.from_raw(raw) for raw in split_changes['ff']['d']] + for split_object in split_objects: + raw = split_object.to_json() + await adapter.set(RedisSplitStorage._FEATURE_FLAG_KEY.format(feature_flag_name=split_object.name), json.dumps(raw)) + await adapter.incr(RedisSplitStorage._TRAFFIC_TYPE_KEY.format(traffic_type_name=split_object.traffic_type_name)) + + original_splits = {split.name: split for split in split_objects} + fetched_splits = {name: await storage.get(name) for name in original_splits.keys()} + + assert set(original_splits.keys()) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + + await adapter.set(RedisSplitStorageAsync._FEATURE_FLAG_TILL_KEY, split_changes['ff']['t']) + assert await storage.get_change_number() == split_changes['ff']['t'] + + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is True + assert await storage.is_valid_traffic_type('anything-else') is False + + finally: + to_delete = [ + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.all_feature", + "SPLITIO.split.killed_feature", + "SPLITIO.split.Risk_Max_Deductible", + "SPLITIO.split.whitelist_feature", + "SPLITIO.split.regex_test", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test", + "SPLITIO.trafficType.user", + "SPLITIO.trafficType.account" + ] + for item in to_delete: + await adapter.delete(item) + + storage = RedisSplitStorageAsync(adapter) + assert await storage.is_valid_traffic_type('user') is False + assert await storage.is_valid_traffic_type('account') is False + + @pytest.mark.asyncio + async def test_get_all(self): + """Test get all names & splits.""" + adapter = await _build_default_client_async({}) + try: + storage = RedisSplitStorageAsync(adapter) + with open(os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json'), 'r') as flo: + split_changes = json.load(flo) + + split_objects = [splits.from_raw(raw) for raw in split_changes['ff']['d']] + for split_object in split_objects: + raw = split_object.to_json() + await adapter.set(RedisSplitStorageAsync._FEATURE_FLAG_KEY.format(feature_flag_name=split_object.name), json.dumps(raw)) + + original_splits = {split.name: split for split in split_objects} + fetched_names = await storage.get_split_names() + fetched_splits = {split.name: split for split in await storage.get_all_splits()} + assert set(fetched_names) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + finally: + await adapter.delete( + 'SPLITIO.split.sample_feature', + 'SPLITIO.splits.till', + 'SPLITIO.split.all_feature', + 'SPLITIO.split.killed_feature', + 'SPLITIO.split.Risk_Max_Deductible', + 'SPLITIO.split.whitelist_feature', + 'SPLITIO.split.regex_test', + 'SPLITIO.split.boolean_test', + 'SPLITIO.split.dependency_test' + ) + +class RedisSegmentStorageAsyncTests(object): + """Redis Segment storage e2e tests.""" + + @pytest.mark.asyncio + async def test_put_fetch_contains(self): + """Test storing and retrieving splits in redis.""" + adapter = await _build_default_client_async({}) + try: + storage = RedisSegmentStorageAsync(adapter) + await adapter.sadd(storage._get_key('some_segment'), 'key1', 'key2', 'key3', 'key4') + await adapter.set(storage._get_till_key('some_segment'), 123) + assert await storage.segment_contains('some_segment', 'key0') == 0 + assert await storage.segment_contains('some_segment', 'key1') == 1 + assert await storage.segment_contains('some_segment', 'key2') == 1 + assert await storage.segment_contains('some_segment', 'key3') == 1 + assert await storage.segment_contains('some_segment', 'key4') == 1 + assert await storage.segment_contains('some_segment', 'key5') == 0 + + fetched = await storage.get('some_segment') + assert fetched.keys == set(['key1', 'key2', 'key3', 'key4']) + assert fetched.change_number == 123 + finally: + await adapter.delete('SPLITIO.segment.some_segment', 'SPLITIO.segment.some_segment.till') + +class RedisImpressionsStorageAsyncTests(object): + """Redis Impressions storage e2e tests.""" + + async def _put_impressions(self, adapter, metadata): + storage = RedisImpressionsStorageAsync(adapter, metadata) + await storage.put([ + impressions.Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + impressions.Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + impressions.Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ]) + + + @pytest.mark.asyncio + async def test_put_fetch_contains(self): + """Test storing and retrieving splits in redis.""" + adapter = await _build_default_client_async({}) + try: + await self._put_impressions(adapter, get_metadata({})) + + imps = await adapter.lrange('SPLITIO.impressions', 0, 2) + assert len(imps) == 3 + for rawImpression in imps: + impression = json.loads(rawImpression) + assert impression['m']['i'] != 'NA' + assert impression['m']['n'] != 'NA' + finally: + await adapter.delete('SPLITIO.impressions') + + @pytest.mark.asyncio + async def test_put_fetch_contains_ip_address_disabled(self): + """Test storing and retrieving splits in redis.""" + adapter = await _build_default_client_async({}) + try: + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + await self._put_impressions(adapter, get_metadata(cfg)) + + imps = await adapter.lrange('SPLITIO.impressions', 0, 2) + assert len(imps) == 3 + for rawImpression in imps: + impression = json.loads(rawImpression) + assert impression['m']['i'] == 'NA' + assert impression['m']['n'] == 'NA' + finally: + await adapter.delete('SPLITIO.impressions') + + +class RedisEventsStorageAsyncTests(object): + """Redis Events storage e2e tests.""" + async def _put_events(self, adapter, metadata): + storage = RedisEventsStorageAsync(adapter, metadata) + await storage.put([ + events.EventWrapper( + event=events.Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + events.EventWrapper( + event=events.Event('key2', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + events.EventWrapper( + event=events.Event('key3', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + ]) + + @pytest.mark.asyncio + async def test_put_fetch_contains(self): + """Test storing and retrieving splits in redis.""" + adapter = await _build_default_client_async({}) + try: + await self._put_events(adapter, get_metadata({})) + evts = await adapter.lrange('SPLITIO.events', 0, 2) + assert len(evts) == 3 + for rawEvent in evts: + event = json.loads(rawEvent) + assert event['m']['i'] != 'NA' + assert event['m']['n'] != 'NA' + finally: + await adapter.delete('SPLITIO.events') + + @pytest.mark.asyncio + async def test_put_fetch_contains_ip_address_disabled(self): + """Test storing and retrieving splits in redis.""" + adapter = await _build_default_client_async({}) + try: + cfg = DEFAULT_CONFIG.copy() + cfg.update({'IPAddressesEnabled': False}) + await self._put_events(adapter, get_metadata(cfg)) + + evts = await adapter.lrange('SPLITIO.events', 0, 2) + assert len(evts) == 3 + for rawEvent in evts: + event = json.loads(rawEvent) + assert event['m']['i'] == 'NA' + assert event['m']['n'] == 'NA' + finally: + await adapter.delete('SPLITIO.events') diff --git a/tests/integration/test_streaming_e2e.py b/tests/integration/test_streaming_e2e.py index 50391baa..d7b3103a 100644 --- a/tests/integration/test_streaming_e2e.py +++ b/tests/integration/test_streaming_e2e.py @@ -4,16 +4,1374 @@ import threading import time import json +import base64 +import pytest + from queue import Queue -from splitio.client.factory import get_factory +from splitio.optional.loaders import asyncio +from splitio.client.factory import get_factory, get_factory_async +from splitio.models.events import SdkEvent +from splitio.events.events_metadata import SdkEventType from tests.helpers.mockserver import SSEMockServer, SplitMockServer from urllib.parse import parse_qs +from splitio.models.telemetry import StreamingEventTypes, SSESyncMode + + +class StreamingIntegrationTests(object): + """Test streaming operation and failover.""" + + update_flag = False + metadata = [] + + def test_happiness(self): + """Test initialization & splits/segment updates.""" + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'on', 'user', True)]}, + 'rbs': {'s': -1, 't': -1, 'd': []} + }, + 1: {'ff': { + 's': 1, + 't': 1, + 'd': []}, + 'rbs': {'s': -1, 't': -1, 'd': []} + } + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000} + } + + factory = get_factory('some_apikey', **kwargs) + factory.client().on(SdkEvent.SDK_UPDATE, self._update_callcack) + factory.block_until_ready(1) + assert factory.ready + assert factory.client().get_treatment('maldo', 'split1') == 'on' + + time.sleep(1) + assert(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events[len(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SYNC_MODE_UPDATE.value) + assert(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events[len(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events)-1]._data == SSESyncMode.STREAMING.value) + split_changes[1] = { + 'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'s': -1, 't': -1, 'd': []} + } + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'s': -1, 't': -1, 'd': []}} + sse_server.publish(make_split_change_event(2)) + time.sleep(1) + flag = False + for meta in self.metadata: + if 'split1' in meta.get_names(): + assert meta.get_type() == SdkEventType.FLAG_UPDATE + flag = True + assert flag + + assert factory.client().get_treatment('maldo', 'split1') == 'off' + + split_changes[2] = { + 'ff': { + 's': 2, + 't': 3, + 'd': [make_split_with_segment('split2', 2, True, False, + 'off', 'user', 'off', 'segment1')]}, + 'rbs': {'s': -1, 't': -1, 'd': []} + } + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'s': -1, 't': -1, 'd': []}} + segment_changes[('segment1', -1)] = { + 'name': 'segment1', + 'added': ['maldo'], + 'removed': [], + 'since': -1, + 'till': 1 + } + segment_changes[('segment1', 1)] = {'name': 'segment1', 'added': [], + 'removed': [], 'since': 1, 'till': 1} + + sse_server.publish(make_split_change_event(3)) + time.sleep(1) + + self._reset_flags() + sse_server.publish(make_segment_change_event('segment1', 1)) + time.sleep(1) + assert self.update_flag + assert self.metadata[len(self.metadata)-1].get_type() == SdkEventType.SEGMENTS_UPDATE + flag = False + for meta in self.metadata: + if 'split2' in meta.get_names(): + assert meta.get_type() == SdkEventType.FLAG_UPDATE + flag = True + assert flag + + assert factory.client().get_treatment('pindon', 'split2') == 'off' + assert factory.client().get_treatment('maldo', 'split2') == 'on' + + self._reset_flags() + sse_server.publish(make_split_fast_change_event(4)) + time.sleep(1) + assert self.update_flag + assert self.metadata[len(self.metadata)-1].get_type() == SdkEventType.FLAG_UPDATE + assert 'split5' in self.metadata[len(self.metadata)-1].get_names() + assert factory.client().get_treatment('maldo', 'split5') == 'on' + + # Validate the SSE request + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + # Initial splits fetch + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after first notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after second notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Segment change notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/segmentChanges/segment1?since=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until segment1 since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/segmentChanges/segment1?since=1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + def _update_callcack(self, metadata): + self.update_flag = True + self.metadata.append(metadata) + + def _reset_flags(self): + self.update_flag = False + + def test_occupancy_flicker(self): + """Test that changes in occupancy switch between polling & streaming properly.""" + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'s': -1, 't': -1, 'd': []} + }, + 1: {'ff': {'s': 1, 't': 1, 'd': []}, + 'rbs': {'s': -1, 't': -1, 'd': []}} + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} + } + + factory = get_factory('some_apikey', **kwargs) + factory.block_until_ready(1) + assert factory.ready + time.sleep(2) + + # Get a hook of the task so we can query its status + task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access + assert not task.running() + + assert factory.client().get_treatment('maldo', 'split1') == 'on' + + # Make a change in the BE but don't send the event. + # After dropping occupancy, the sdk should switch to polling + # and perform a syncAll that gets this change + split_changes[1] = { + 'ff': {'s': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + + sse_server.publish(make_occupancy('control_pri', 0)) + sse_server.publish(make_occupancy('control_sec', 0)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + assert task.running() + + # We make another chagne in the BE and don't send the event. + # We restore occupancy, and it should be fetched by the + # sync all after streaming is restored. + split_changes[2] = { + 'ff': {'s': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + + sse_server.publish(make_occupancy('control_pri', 1)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert not task.running() + + # Now we make another change and send an event so it's propagated + split_changes[3] = { + 'ff': {'s': 3, + 't': 4, + 'd': [make_simple_split('split1', 4, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[4] = {'ff': {'s': 4, 't': 4, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + sse_server.publish(make_split_change_event(4)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + + # Kill the split + split_changes[4] = { + 'ff': {'s': 4, + 't': 5, + 'd': [make_simple_split('split1', 5, True, True, 'frula', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[5] = {'ff': {'s': 5, 't': 5, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + sse_server.publish(make_split_kill_event('split1', 'frula', 5)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'frula' + + # Validate the SSE request + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + # Initial splits fetch + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after first notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after second notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Split kill + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=5&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + def test_start_without_occupancy(self): + """Test an SDK starting with occupancy on 0 and switching to streamin afterwards.""" + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + }, + 1: {'ff': {'s': 1, 't': 1, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 0)) + sse_server.publish(make_occupancy('control_sec', 0)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} + } + + factory = get_factory('some_apikey', **kwargs) + factory.block_until_ready(1) + assert factory.ready + time.sleep(2) + + # Get a hook of the task so we can query its status + task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access + assert task.running() + assert factory.client().get_treatment('maldo', 'split1') == 'on' + + # Make a change in the BE but don't send the event. + # After restoring occupancy, the sdk should switch to polling + # and perform a syncAll that gets this change + split_changes[1] = { + 'ff': {'s': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + + sse_server.publish(make_occupancy('control_sec', 1)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + assert not task.running() + + # Validate the SSE request + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + # Initial splits fetch + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after push down + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after push restored + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Second iteration of previous syncAll + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + def test_streaming_status_changes(self): + """Test changes between streaming enabled, paused and disabled.""" + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + }, + 1: {'ff': {'s': 1, 't': 1, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} + } + + factory = get_factory('some_apikey', **kwargs) + factory.block_until_ready(1) + assert factory.ready + time.sleep(2) + + # Get a hook of the task so we can query its status + task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access + assert not task.running() + + assert factory.client().get_treatment('maldo', 'split1') == 'on' + + # Make a change in the BE but don't send the event. + # After dropping occupancy, the sdk should switch to polling + # and perform a syncAll that gets this change + split_changes[1] = { + 'ff': {'s': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + + sse_server.publish(make_control_event('STREAMING_PAUSED', 1)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + assert task.running() + + # We make another chagne in the BE and don't send the event. + # We restore occupancy, and it should be fetched by the + # sync all after streaming is restored. + split_changes[2] = { + 'ff': {'s': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + + sse_server.publish(make_control_event('STREAMING_ENABLED', 2)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert not task.running() + + # Now we make another change and send an event so it's propagated + split_changes[3] = { + 'ff': {'s': 3, + 't': 4, + 'd': [make_simple_split('split1', 4, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[4] = {'ff': {'s': 4, 't': 4, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + sse_server.publish(make_split_change_event(4)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + assert not task.running() + + split_changes[4] = { + 'ff': {'s': 4, + 't': 5, + 'd': [make_simple_split('split1', 5, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[5] = {'ff': {'s': 5, 't': 5, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + sse_server.publish(make_control_event('STREAMING_DISABLED', 2)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert task.running() + assert 'PushStatusHandler' not in [t.name for t in threading.enumerate()] + + # Validate the SSE request + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + # Initial splits fetch + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll on push down + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after push is up + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming disabled + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=5&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + def test_server_closes_connection(self): + """Test that if the server closes the connection, the whole flow is retried with BO.""" + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'on', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + }, + 1: {'ff': { + 's': 1, + 't': 1, + 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 100, + 'segmentsRefreshRate': 100, 'metricsRefreshRate': 100, + 'impressionsRefreshRate': 100, 'eventsPushRate': 100} + } + + factory = get_factory('some_apikey', **kwargs) + factory.block_until_ready(1) + assert factory.ready + assert factory.client().get_treatment('maldo', 'split1') == 'on' + task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access + assert not task.running() + + time.sleep(1) + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + sse_server.publish(make_split_change_event(2)) + time.sleep(1) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + + sse_server.publish(SSEMockServer.GRACEFUL_REQUEST_END) + time.sleep(1) + assert factory.client().get_treatment('maldo', 'split1') == 'off' + assert task.running() + + time.sleep(2) # wait for the backoff to expire so streaming gets re-attached + + # re-send initial event AND occupancy + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + time.sleep(2) + + assert not task.running() + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': [], + 'rbs': {'t': -1, 's': -1, 'd': []}}} + sse_server.publish(make_split_change_event(3)) + time.sleep(1) + assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert not task.running() + + # Validate the SSE requests + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + # Initial splits fetch + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after first notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll on retryable error handling + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth after connection breaks + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected again + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after new notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + def test_ably_errors_handling(self): + """Test incoming ably errors and validate its handling.""" + import logging + logger = logging.getLogger('splitio') + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + }, + 1: {'ff': {'s': 1, 't': 1, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} + } + + factory = get_factory('some_apikey', **kwargs) + factory.block_until_ready(1) + assert factory.ready + time.sleep(2) + + # Get a hook of the task so we can query its status + task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access + assert not task.running() + + assert factory.client().get_treatment('maldo', 'split1') == 'on' + + # Make a change in the BE but don't send the event. + # We'll send an ignorable error and check it has nothing happened + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + + sse_server.publish(make_ably_error_event(60000, 600)) + time.sleep(1) + assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert not task.running() + + sse_server.publish(make_ably_error_event(40145, 401)) + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + time.sleep(3) + assert task.running() + assert factory.client().get_treatment('maldo', 'split1') == 'off' + + # Re-publish initial events so that the retry succeeds + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + time.sleep(3) + assert not task.running() + + # Assert streaming is working properly + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + } + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + sse_server.publish(make_split_change_event(3)) + time.sleep(2) + assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert not task.running() + + # Send a non-retryable ably error + sse_server.publish(make_ably_error_event(40200, 402)) + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + time.sleep(3) + + # Assert sync-task is running and the streaming status handler thread is over + assert task.running() + assert 'PushStatusHandler' not in [t.name for t in threading.enumerate()] + + # Validate the SSE requests + sse_request = sse_requests.get() + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + assert sse_request.method == 'GET' + path, qs = sse_request.path.split('?', 1) + assert path == '/event-stream' + qs = parse_qs(qs) + assert qs['accessToken'][0] == ( + 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05' + 'US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UW' + 'XlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjc' + 'mliZVwiXSxcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcI' + 'jpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY' + '2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJzXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzd' + 'WJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFib' + 'HktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4cCI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0M' + 'Dk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5EvJh17WlOlAKhcD0' + ) + + assert set(qs['channels'][0].split(',')) == set(['MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_segments', + '[?occupancy=metrics.publishers]control_pri', + '[?occupancy=metrics.publishers]control_sec']) + assert qs['v'][0] == '1.1' + + # Initial splits fetch + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after streaming connected + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll retriable error + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Auth again + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/v2/auth?s=1.3' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after push is up + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Fetch after notification + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Iteration until since == till + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # SyncAll after non recoverable ably error + req = split_backend_requests.get() + assert req.method == 'GET' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' + assert req.headers['authorization'] == 'Bearer some_apikey' + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + def test_change_number(mocker): + # test if changeNumber is missing + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + }, + 1: {'ff': {'s': 1, 't': 1, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) -class StreamingIntegrationTests(object): + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} + } + + factory = get_factory('some_apikey', **kwargs) + factory.block_until_ready(1) + assert factory.ready + time.sleep(2) + + split_changes = make_split_fast_change_event(5).copy() + data = json.loads(split_changes['data']) + inner_data = json.loads(data['data']) + inner_data['changeNumber'] = None + data['data'] = json.dumps(inner_data) + split_changes['data'] = json.dumps(data) + sse_server.publish(split_changes) + time.sleep(1) + assert factory._storages['splits'].get_change_number() == 1 + + # Cleanup + destroy_event = threading.Event() + factory.destroy(destroy_event) + destroy_event.wait() + sse_server.publish(sse_server.GRACEFUL_REQUEST_END) + sse_server.stop() + split_backend.stop() + + +class StreamingIntegrationAsyncTests(object): """Test streaming operation and failover.""" - def test_happiness(self): + update_flag = False + metadata = [] + + @pytest.mark.asyncio + async def test_happiness(self): """Test initialization & splits/segment updates.""" auth_server_response = { 'pushEnabled': True, @@ -29,15 +1387,17 @@ def test_happiness(self): } split_changes = { - -1: { - 'since': -1, - 'till': 1, - 'splits': [make_simple_split('split1', 1, True, False, 'on', 'user', True)] + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'on', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} }, - 1: { - 'since': 1, - 'till': 1, - 'splits': [] + 1: {'ff': { + 's': 1, + 't': 1, + 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []} } } @@ -62,29 +1422,43 @@ def test_happiness(self): 'config': {'connectTimeout': 10000} } - factory = get_factory('some_apikey', **kwargs) - factory.block_until_ready(1) + factory = await get_factory_async('some_apikey', **kwargs) + await factory.block_until_ready(1) + await factory.client().on(SdkEvent.SDK_UPDATE, self._update_callcack) assert factory.ready - assert factory.client().get_treatment('maldo', 'split1') == 'on' - - time.sleep(1) - split_changes[1] = { - 'since': 1, - 'till': 2, - 'splits': [make_simple_split('split1', 2, True, False, 'off', 'user', False)] + assert await factory.client().get_treatment('maldo', 'split1') == 'on' + + await asyncio.sleep(1) + assert(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events[len(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SYNC_MODE_UPDATE.value) + assert(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events[len(factory._telemetry_evaluation_producer._telemetry_storage._streaming_events._streaming_events)-1]._data == SSESyncMode.STREAMING.value) + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[2] = {'since': 2, 'till': 2, 'splits': []} + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_change_event(2)) - time.sleep(1) - assert factory.client().get_treatment('maldo', 'split1') == 'off' - - split_changes[2] = { - 'since': 2, - 'till': 3, - 'splits': [make_split_with_segment('split2', 2, True, False, - 'off', 'user', 'off', 'segment1')] + await asyncio.sleep(1) + flag = False + for meta in self.metadata: + if 'split1' in meta.get_names(): + assert meta.get_type() == SdkEventType.FLAG_UPDATE + flag = True + assert flag + + assert await factory.client().get_treatment('maldo', 'split1') == 'off' + + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_split_with_segment('split2', 2, True, False, + 'off', 'user', 'off', 'segment1')]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[3] = {'since': 3, 'till': 3, 'splits': []} + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} segment_changes[('segment1', -1)] = { 'name': 'segment1', 'added': ['maldo'], @@ -96,12 +1470,12 @@ def test_happiness(self): 'removed': [], 'since': 1, 'till': 1} sse_server.publish(make_split_change_event(3)) - time.sleep(1) + await asyncio.sleep(1) sse_server.publish(make_segment_change_event('segment1', 1)) - time.sleep(1) + await asyncio.sleep(1) - assert factory.client().get_treatment('pindon', 'split2') == 'off' - assert factory.client().get_treatment('maldo', 'split2') == 'on' + assert await factory.client().get_treatment('pindon', 'split2') == 'off' + assert await factory.client().get_treatment('maldo', 'split2') == 'on' # Validate the SSE request sse_request = sse_requests.get() @@ -127,58 +1501,52 @@ def test_happiness(self): '[?occupancy=metrics.publishers]control_sec']) assert qs['v'][0] == '1.1' - # Initial apikey validation - req = split_backend_requests.get() - assert req.method == 'GET' - assert req.path == '/api/segmentChanges/__SOME_INVALID_SEGMENT__?since=-1' - assert req.headers['authorization'] == 'Bearer some_apikey' - # Initial splits fetch req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=-1' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after first notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after second notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Segment change notification @@ -194,14 +1562,17 @@ def test_happiness(self): assert req.headers['authorization'] == 'Bearer some_apikey' # Cleanup - destroy_event = threading.Event() - factory.destroy(destroy_event) - destroy_event.wait() + await factory.destroy() sse_server.publish(sse_server.GRACEFUL_REQUEST_END) sse_server.stop() split_backend.stop() - def test_occupancy_flicker(self): + async def _update_callcack(self, metadata): + self.update_flag = True + self.metadata.append(metadata) + + @pytest.mark.asyncio + async def test_occupancy_flicker(self): """Test that changes in occupancy switch between polling & streaming properly.""" auth_server_response = { 'pushEnabled': True, @@ -217,12 +1588,14 @@ def test_occupancy_flicker(self): } split_changes = { - -1: { - 'since': -1, - 'till': 1, - 'splits': [make_simple_split('split1', 1, True, False, 'off', 'user', True)] + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} }, - 1: {'since': 1, 'till': 1, 'splits': []} + 1: {'ff': {'s': 1, 't': 1, 'd': []}, + 'rbs': {'t': -1, 's': -1, 'd': []}} } segment_changes = {} @@ -246,69 +1619,71 @@ def test_occupancy_flicker(self): 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} } - factory = get_factory('some_apikey', **kwargs) - factory.block_until_ready(1) + factory = await get_factory_async('some_apikey', **kwargs) + await factory.block_until_ready(1) assert factory.ready - time.sleep(2) + await asyncio.sleep(2) # Get a hook of the task so we can query its status task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access assert not task.running() - assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert await factory.client().get_treatment('maldo', 'split1') == 'on' # Make a change in the BE but don't send the event. # After dropping occupancy, the sdk should switch to polling # and perform a syncAll that gets this change - split_changes[1] = { - 'since': 1, - 'till': 2, - 'splits': [make_simple_split('split1', 2, True, False, 'off', 'user', False)] + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[2] = {'since': 2, 'till': 2, 'splits': []} - + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_occupancy('control_pri', 0)) sse_server.publish(make_occupancy('control_sec', 0)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(2) + assert await factory.client().get_treatment('maldo', 'split1') == 'off' assert task.running() # We make another chagne in the BE and don't send the event. # We restore occupancy, and it should be fetched by the # sync all after streaming is restored. - split_changes[2] = { - 'since': 2, - 'till': 3, - 'splits': [make_simple_split('split1', 3, True, False, 'off', 'user', True)] + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[3] = {'since': 3, 'till': 3, 'splits': []} - + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_occupancy('control_pri', 1)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'on' + await asyncio.sleep(2) + assert await factory.client().get_treatment('maldo', 'split1') == 'on' assert not task.running() # Now we make another change and send an event so it's propagated - split_changes[3] = { - 'since': 3, - 'till': 4, - 'splits': [make_simple_split('split1', 4, True, False, 'off', 'user', False)] + split_changes[3] = {'ff': { + 's': 3, + 't': 4, + 'd': [make_simple_split('split1', 4, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[4] = {'since': 4, 'till': 4, 'splits': []} + split_changes[4] = {'ff': {'s': 4, 't': 4, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_change_event(4)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(2) + assert await factory.client().get_treatment('maldo', 'split1') == 'off' # Kill the split - split_changes[4] = { - 'since': 4, - 'till': 5, - 'splits': [make_simple_split('split1', 5, True, True, 'frula', 'user', False)] + split_changes[4] = {'ff': { + 's': 4, + 't': 5, + 'd': [make_simple_split('split1', 5, True, True, 'frula', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[5] = {'since': 5, 'till': 5, 'splits': []} + split_changes[5] = {'ff': {'s': 5, 't': 5, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_kill_event('split1', 'frula', 5)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'frula' + await asyncio.sleep(2) + assert await factory.client().get_treatment('maldo', 'split1') == 'frula' # Validate the SSE request sse_request = sse_requests.get() @@ -334,93 +1709,86 @@ def test_occupancy_flicker(self): '[?occupancy=metrics.publishers]control_sec']) assert qs['v'][0] == '1.1' - # Initial apikey validation - req = split_backend_requests.get() - assert req.method == 'GET' - assert req.path == '/api/segmentChanges/__SOME_INVALID_SEGMENT__?since=-1' - assert req.headers['authorization'] == 'Bearer some_apikey' - # Initial splits fetch req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=-1' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after first notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after second notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=4' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Split kill req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=4' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=5' + assert req.path == '/api/splitChanges?s=1.3&since=5&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Cleanup - destroy_event = threading.Event() - factory.destroy(destroy_event) - destroy_event.wait() + await factory.destroy() sse_server.publish(sse_server.GRACEFUL_REQUEST_END) sse_server.stop() split_backend.stop() - def test_start_without_occupancy(self): + @pytest.mark.asyncio + async def test_start_without_occupancy(self): """Test an SDK starting with occupancy on 0 and switching to streamin afterwards.""" auth_server_response = { 'pushEnabled': True, @@ -436,12 +1804,13 @@ def test_start_without_occupancy(self): } split_changes = { - -1: { - 'since': -1, - 'till': 1, - 'splits': [make_simple_split('split1', 1, True, False, 'off', 'user', True)] + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} }, - 1: {'since': 1, 'till': 1, 'splits': []} + 1: {'ff': {'s': 1, 't': 1, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} } segment_changes = {} @@ -465,29 +1834,33 @@ def test_start_without_occupancy(self): 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} } - factory = get_factory('some_apikey', **kwargs) - factory.block_until_ready(1) + factory = await get_factory_async('some_apikey', **kwargs) + try: + await factory.block_until_ready(1) + except Exception: + pass assert factory.ready - time.sleep(2) + await asyncio.sleep(2) # Get a hook of the task so we can query its status task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access assert task.running() - assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert await factory.client().get_treatment('maldo', 'split1') == 'on' # Make a change in the BE but don't send the event. # After restoring occupancy, the sdk should switch to polling # and perform a syncAll that gets this change - split_changes[1] = { - 'since': 1, - 'till': 2, - 'splits': [make_simple_split('split1', 2, True, False, 'off', 'user', False)] + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[2] = {'since': 2, 'till': 2, 'splits': []} + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_occupancy('control_sec', 1)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(2) + assert await factory.client().get_treatment('maldo', 'split1') == 'off' assert not task.running() # Validate the SSE request @@ -514,63 +1887,56 @@ def test_start_without_occupancy(self): '[?occupancy=metrics.publishers]control_sec']) assert qs['v'][0] == '1.1' - # Initial apikey validation - req = split_backend_requests.get() - assert req.method == 'GET' - assert req.path == '/api/segmentChanges/__SOME_INVALID_SEGMENT__?since=-1' - assert req.headers['authorization'] == 'Bearer some_apikey' - # Initial splits fetch req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=-1' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after push down req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after push restored req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Second iteration of previous syncAll req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Cleanup - destroy_event = threading.Event() - factory.destroy(destroy_event) - destroy_event.wait() + await factory.destroy() sse_server.publish(sse_server.GRACEFUL_REQUEST_END) sse_server.stop() split_backend.stop() - def test_streaming_status_changes(self): + @pytest.mark.asyncio + async def test_streaming_status_changes(self): """Test changes between streaming enabled, paused and disabled.""" auth_server_response = { 'pushEnabled': True, @@ -586,12 +1952,13 @@ def test_streaming_status_changes(self): } split_changes = { - -1: { - 'since': -1, - 'till': 1, - 'splits': [make_simple_split('split1', 1, True, False, 'off', 'user', True)] + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} }, - 1: {'since': 1, 'till': 1, 'splits': []} + 1: {'ff': {'s': 1, 't': 1, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} } segment_changes = {} @@ -615,70 +1982,80 @@ def test_streaming_status_changes(self): 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} } - factory = get_factory('some_apikey', **kwargs) - factory.block_until_ready(1) + factory = await get_factory_async('some_apikey', **kwargs) + try: + await factory.block_until_ready(1) + except Exception: + pass assert factory.ready - time.sleep(2) + await asyncio.sleep(2) # Get a hook of the task so we can query its status task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access assert not task.running() - assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert await factory.client().get_treatment('maldo', 'split1') == 'on' # Make a change in the BE but don't send the event. # After dropping occupancy, the sdk should switch to polling # and perform a syncAll that gets this change - split_changes[1] = { - 'since': 1, - 'till': 2, - 'splits': [make_simple_split('split1', 2, True, False, 'off', 'user', False)] + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[2] = {'since': 2, 'till': 2, 'splits': []} + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_control_event('STREAMING_PAUSED', 1)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(4) + + assert await factory.client().get_treatment('maldo', 'split1') == 'off' assert task.running() # We make another chagne in the BE and don't send the event. # We restore occupancy, and it should be fetched by the # sync all after streaming is restored. - split_changes[2] = { - 'since': 2, - 'till': 3, - 'splits': [make_simple_split('split1', 3, True, False, 'off', 'user', True)] + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[3] = {'since': 3, 'till': 3, 'splits': []} + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_control_event('STREAMING_ENABLED', 2)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'on' + await asyncio.sleep(2) + + assert await factory.client().get_treatment('maldo', 'split1') == 'on' assert not task.running() # Now we make another change and send an event so it's propagated - split_changes[3] = { - 'since': 3, - 'till': 4, - 'splits': [make_simple_split('split1', 4, True, False, 'off', 'user', False)] + split_changes[3] = {'ff': { + 's': 3, + 't': 4, + 'd': [make_simple_split('split1', 4, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[4] = {'since': 4, 'till': 4, 'splits': []} + split_changes[4] = {'ff': {'s': 4, 't': 4, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_change_event(4)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(2) + + assert await factory.client().get_treatment('maldo', 'split1') == 'off' assert not task.running() - split_changes[4] = { - 'since': 4, - 'till': 5, - 'splits': [make_simple_split('split1', 5, True, False, 'off', 'user', True)] + split_changes[4] = {'ff': { + 's': 4, + 't': 5, + 'd': [make_simple_split('split1', 5, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[5] = {'since': 5, 'till': 5, 'splits': []} + split_changes[5] = {'ff': {'s': 5, 't': 5, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_control_event('STREAMING_DISABLED', 2)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'on' + await asyncio.sleep(2) + + assert await factory.client().get_treatment('maldo', 'split1') == 'on' assert task.running() - assert 'PushStatusHandler' not in [t.name for t in threading.enumerate()] # Validate the SSE request sse_request = sse_requests.get() @@ -704,93 +2081,86 @@ def test_streaming_status_changes(self): '[?occupancy=metrics.publishers]control_sec']) assert qs['v'][0] == '1.1' - # Initial apikey validation - req = split_backend_requests.get() - assert req.method == 'GET' - assert req.path == '/api/segmentChanges/__SOME_INVALID_SEGMENT__?since=-1' - assert req.headers['authorization'] == 'Bearer some_apikey' - # Initial splits fetch req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=-1' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll on push down req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after push is up req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=4' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming disabled req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=4' + assert req.path == '/api/splitChanges?s=1.3&since=4&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=5' + assert req.path == '/api/splitChanges?s=1.3&since=5&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Cleanup - destroy_event = threading.Event() - factory.destroy(destroy_event) - destroy_event.wait() + await factory.destroy() sse_server.publish(sse_server.GRACEFUL_REQUEST_END) sse_server.stop() split_backend.stop() - def test_server_closes_connection(self): + @pytest.mark.asyncio + async def test_server_closes_connection(self): """Test that if the server closes the connection, the whole flow is retried with BO.""" auth_server_response = { 'pushEnabled': True, @@ -806,16 +2176,13 @@ def test_server_closes_connection(self): } split_changes = { - -1: { - 'since': -1, - 'till': 1, - 'splits': [make_simple_split('split1', 1, True, False, 'on', 'user', True)] + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'on', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} }, - 1: { - 'since': 1, - 'till': 1, - 'splits': [] - } + 1: {'ff': {'s': 1, 't': 1, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} } segment_changes = {} @@ -840,48 +2207,51 @@ def test_server_closes_connection(self): 'segmentsRefreshRate': 100, 'metricsRefreshRate': 100, 'impressionsRefreshRate': 100, 'eventsPushRate': 100} } - - factory = get_factory('some_apikey', **kwargs) - factory.block_until_ready(1) + factory = await get_factory_async('some_apikey', **kwargs) + await factory.block_until_ready(1) assert factory.ready - assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert await factory.client().get_treatment('maldo', 'split1') == 'on' task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access assert not task.running() - time.sleep(1) - split_changes[1] = { - 'since': 1, - 'till': 2, - 'splits': [make_simple_split('split1', 2, True, False, 'off', 'user', False)] + await asyncio.sleep(1) + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[2] = {'since': 2, 'till': 2, 'splits': []} + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_change_event(2)) - time.sleep(1) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(1) + assert await factory.client().get_treatment('maldo', 'split1') == 'off' sse_server.publish(SSEMockServer.GRACEFUL_REQUEST_END) - time.sleep(1) - assert factory.client().get_treatment('maldo', 'split1') == 'off' + await asyncio.sleep(1) + assert await factory.client().get_treatment('maldo', 'split1') == 'off' assert task.running() - time.sleep(2) # wait for the backoff to expire so streaming gets re-attached +# # wait for the backoff to expire so streaming gets re-attached + await asyncio.sleep(2) # re-send initial event AND occupancy sse_server.publish(make_initial_event()) sse_server.publish(make_occupancy('control_pri', 2)) sse_server.publish(make_occupancy('control_sec', 2)) - time.sleep(2) + await asyncio.sleep(2) assert not task.running() - split_changes[2] = { - 'since': 2, - 'till': 3, - 'splits': [make_simple_split('split1', 3, True, False, 'off', 'user', True)] + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[3] = {'since': 3, 'till': 3, 'splits': []} + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_change_event(3)) - time.sleep(1) - assert factory.client().get_treatment('maldo', 'split1') == 'on' + await asyncio.sleep(1) + + assert await factory.client().get_treatment('maldo', 'split1') == 'on' assert not task.running() # Validate the SSE requests @@ -931,87 +2301,80 @@ def test_server_closes_connection(self): '[?occupancy=metrics.publishers]control_sec']) assert qs['v'][0] == '1.1' - # Initial apikey validation - req = split_backend_requests.get() - assert req.method == 'GET' - assert req.path == '/api/segmentChanges/__SOME_INVALID_SEGMENT__?since=-1' - assert req.headers['authorization'] == 'Bearer some_apikey' - # Initial splits fetch req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=-1' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after first notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll on retryable error handling req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth after connection breaks req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected again req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after new notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Cleanup - destroy_event = threading.Event() - factory.destroy(destroy_event) - destroy_event.wait() + await factory.destroy() sse_server.publish(sse_server.GRACEFUL_REQUEST_END) sse_server.stop() split_backend.stop() - def test_ably_errors_handling(self): + @pytest.mark.asyncio + async def test_ably_errors_handling(self): """Test incoming ably errors and validate its handling.""" import logging logger = logging.getLogger('splitio') @@ -1034,12 +2397,13 @@ def test_ably_errors_handling(self): } split_changes = { - -1: { - 'since': -1, - 'till': 1, - 'splits': [make_simple_split('split1', 1, True, False, 'off', 'user', True)] + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} }, - 1: {'since': 1, 'till': 1, 'splits': []} + 1: {'ff': {'s': 1, 't': 1, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} } segment_changes = {} @@ -1063,64 +2427,69 @@ def test_ably_errors_handling(self): 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 10} } - factory = get_factory('some_apikey', **kwargs) - factory.block_until_ready(1) + factory = await get_factory_async('some_apikey', **kwargs) + try: + await factory.block_until_ready(5) + except Exception: + pass assert factory.ready - time.sleep(2) - + await asyncio.sleep(2) # Get a hook of the task so we can query its status task = factory._sync_manager._synchronizer._split_tasks.split_task._task # pylint:disable=protected-access assert not task.running() - assert factory.client().get_treatment('maldo', 'split1') == 'on' + assert await factory.client().get_treatment('maldo', 'split1') == 'on' # Make a change in the BE but don't send the event. # We'll send an ignorable error and check it has nothing happened - split_changes[1] = { - 'since': 1, - 'till': 2, - 'splits': [make_simple_split('split1', 2, True, False, 'off', 'user', False)] + split_changes[1] = {'ff': { + 's': 1, + 't': 2, + 'd': [make_simple_split('split1', 2, True, False, 'off', 'user', False)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[2] = {'since': 2, 'till': 2, 'splits': []} + split_changes[2] = {'ff': {'s': 2, 't': 2, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_ably_error_event(60000, 600)) - time.sleep(1) - assert factory.client().get_treatment('maldo', 'split1') == 'on' + await asyncio.sleep(1) + + assert await factory.client().get_treatment('maldo', 'split1') == 'on' assert not task.running() sse_server.publish(make_ably_error_event(40145, 401)) sse_server.publish(sse_server.GRACEFUL_REQUEST_END) - time.sleep(3) + await asyncio.sleep(3) + assert task.running() - assert factory.client().get_treatment('maldo', 'split1') == 'off' + assert await factory.client().get_treatment('maldo', 'split1') == 'off' # Re-publish initial events so that the retry succeeds sse_server.publish(make_initial_event()) sse_server.publish(make_occupancy('control_pri', 2)) sse_server.publish(make_occupancy('control_sec', 2)) - time.sleep(3) + await asyncio.sleep(3) assert not task.running() # Assert streaming is working properly - split_changes[2] = { - 'since': 2, - 'till': 3, - 'splits': [make_simple_split('split1', 3, True, False, 'off', 'user', True)] + split_changes[2] = {'ff': { + 's': 2, + 't': 3, + 'd': [make_simple_split('split1', 3, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} } - split_changes[3] = {'since': 3, 'till': 3, 'splits': []} + split_changes[3] = {'ff': {'s': 3, 't': 3, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} sse_server.publish(make_split_change_event(3)) - time.sleep(2) - assert factory.client().get_treatment('maldo', 'split1') == 'on' + await asyncio.sleep(2) + assert await factory.client().get_treatment('maldo', 'split1') == 'on' assert not task.running() # Send a non-retryable ably error sse_server.publish(make_ably_error_event(40200, 402)) sse_server.publish(sse_server.GRACEFUL_REQUEST_END) - time.sleep(3) + await asyncio.sleep(3) # Assert sync-task is running and the streaming status handler thread is over assert task.running() - assert 'PushStatusHandler' not in [t.name for t in threading.enumerate()] # Validate the SSE requests sse_request = sse_requests.get() @@ -1168,86 +2537,144 @@ def test_ably_errors_handling(self): '[?occupancy=metrics.publishers]control_sec']) assert qs['v'][0] == '1.1' - # Initial apikey validation - req = split_backend_requests.get() - assert req.method == 'GET' - assert req.path == '/api/segmentChanges/__SOME_INVALID_SEGMENT__?since=-1' - assert req.headers['authorization'] == 'Bearer some_apikey' - # Initial splits fetch req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=-1' + assert req.path == '/api/splitChanges?s=1.3&since=-1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after streaming connected req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll retriable error req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=1' + assert req.path == '/api/splitChanges?s=1.3&since=1&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Auth again req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/v2/auth' + assert req.path == '/api/v2/auth?s=1.3' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after push is up req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Fetch after notification req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=2' + assert req.path == '/api/splitChanges?s=1.3&since=2&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Iteration until since == till req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # SyncAll after non recoverable ably error req = split_backend_requests.get() assert req.method == 'GET' - assert req.path == '/api/splitChanges?since=3' + assert req.path == '/api/splitChanges?s=1.3&since=3&rbSince=-1' assert req.headers['authorization'] == 'Bearer some_apikey' # Cleanup - destroy_event = threading.Event() - factory.destroy(destroy_event) - destroy_event.wait() + await factory.destroy() sse_server.publish(sse_server.GRACEFUL_REQUEST_END) sse_server.stop() split_backend.stop() + @pytest.mark.asyncio + async def test_change_number(mocker): + # test if changeNumber is missing + auth_server_response = { + 'pushEnabled': True, + 'token': ('eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.' + 'eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk1UWXlNVGN4T1RRNE13PT1fTWpBNE16Y3pO' + 'RFUxTWc9PV9zZWdtZW50c1wiOltcInN1YnNjcmliZVwiXSxcIk1UWXlNVGN4T1RRNE13P' + 'T1fTWpBNE16Y3pORFUxTWc9PV9zcGxpdHNcIjpbXCJzdWJzY3JpYmVcIl0sXCJjb250cm' + '9sX3ByaVwiOltcInN1YnNjcmliZVwiLFwiY2hhbm5lbC1tZXRhZGF0YTpwdWJsaXNoZXJ' + 'zXCJdLFwiY29udHJvbF9zZWNcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRh' + 'dGE6cHVibGlzaGVyc1wiXX0iLCJ4LWFibHktY2xpZW50SWQiOiJjbGllbnRJZCIsImV4c' + 'CI6MTYwNDEwMDU5MSwiaWF0IjoxNjA0MDk2OTkxfQ.aP9BfR534K6J9h8gfDWg_CQgpz5E' + 'vJh17WlOlAKhcD0') + } + + split_changes = { + -1: {'ff': { + 's': -1, + 't': 1, + 'd': [make_simple_split('split1', 1, True, False, 'off', 'user', True)]}, + 'rbs': {'t': -1, 's': -1, 'd': []} + }, + 1: {'ff': {'s': 1, 't': 1, 'd': []}, 'rbs': {'t': -1, 's': -1, 'd': []}} + } + + segment_changes = {} + split_backend_requests = Queue() + split_backend = SplitMockServer(split_changes, segment_changes, split_backend_requests, + auth_server_response) + sse_requests = Queue() + sse_server = SSEMockServer(sse_requests) + + split_backend.start() + sse_server.start() + sse_server.publish(make_initial_event()) + sse_server.publish(make_occupancy('control_pri', 2)) + sse_server.publish(make_occupancy('control_sec', 2)) + + kwargs = { + 'sdk_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'events_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'auth_api_base_url': 'http://localhost:%d/api' % split_backend.port(), + 'streaming_api_base_url': 'http://localhost:%d' % sse_server.port(), + 'config': {'connectTimeout': 10000, 'featuresRefreshRate': 100} + } + factory2 = await get_factory_async('some_apikey', **kwargs) + await factory2.block_until_ready(1) + assert factory2.ready + await asyncio.sleep(2) + + split_changes = make_split_fast_change_event(5).copy() + data = json.loads(split_changes['data']) + inner_data = json.loads(data['data']) + inner_data['changeNumber'] = None + data['data'] = json.dumps(inner_data) + split_changes['data'] = json.dumps(data) + sse_server.publish(split_changes) + await asyncio.sleep(1) + assert await factory2._storages['splits'].get_change_number() == 1 + + # Cleanup + await factory2.destroy() + sse_server.publish(sse_server.VIOLENT_REQUEST_END) + sse_server.stop() + split_backend.stop() def make_split_change_event(change_number): """Make a split change event.""" @@ -1266,6 +2693,32 @@ def make_split_change_event(change_number): }) } +def make_split_fast_change_event(change_number): + """Make a split change event.""" + json1 = make_simple_split('split5', 1, True, False, 'off', 'user', True) + str1 = json.dumps(json1) + byt1 = bytes(str1, encoding='utf-8') + compressed = base64.b64encode(byt1) + final = compressed.decode('utf-8') + + return { + 'event': 'message', + 'data': json.dumps({ + 'id':'TVUsxaabHs:0:0', + 'clientId':'pri:MzM0ODI1MTkxMw==', + 'timestamp': change_number-1, + 'encoding':'json', + 'channel':'MTYyMTcxOTQ4Mw==_MjA4MzczNDU1Mg==_splits', + 'data': json.dumps({ + 'type': 'SPLIT_UPDATE', + 'changeNumber': change_number, + 'pcn': 3, + 'c': 0, + 'd': final + }) + }) + } + def make_split_kill_event(name, default_treatment, change_number): """Make a split change event.""" return { @@ -1412,6 +2865,23 @@ def make_split_with_segment(name, cn, active, killed, default_treatment, 'treatment': 'on' if on else 'off', 'size': 100 }] + }, + { + 'matcherGroup': { + 'combiner': 'AND', + 'matchers': [ + { + 'matcherType': 'ALL_KEYS', + 'negate': False, + 'userDefinedSegmentMatcherData': None, + 'whitelistMatcherData': None + } + ] + }, + 'partitions': [ + {'treatment': 'on' if on else 'off', 'size': 0}, + {'treatment': 'off' if on else 'on', 'size': 100} + ] } ] } diff --git a/tests/models/grammar/files/between-semver.csv b/tests/models/grammar/files/between-semver.csv new file mode 100644 index 00000000..71bdf3b2 --- /dev/null +++ b/tests/models/grammar/files/between-semver.csv @@ -0,0 +1,18 @@ +version1,version2,version3,expected +1.1.1,2.2.2,3.3.3,true +1.1.1-rc.1,1.1.1-rc.2,1.1.1-rc.3,true +1.0.0-alpha,1.0.0-alpha.1,1.0.0-alpha.beta,true +1.0.0-alpha.1,1.0.0-alpha.beta,1.0.0-beta,true +1.0.0-alpha.beta,1.0.0-beta,1.0.0-beta.2,true +1.0.0-beta,1.0.0-beta.2,1.0.0-beta.11,true +1.0.0-beta.2,1.0.0-beta.11,1.0.0-rc.1,true +1.0.0-beta.11,1.0.0-rc.1,1.0.0,true +1.1.2,1.1.3,1.1.4,true +1.2.1,1.3.1,1.4.1,true +2.0.0,3.0.0,4.0.0,true +2.2.2,2.2.3-rc1,2.2.3,true +2.2.2,2.3.2-rc100,2.3.3,true +1.0.0-rc.1+build.1,1.2.3-beta,1.2.3-rc.1+build.123,true +3.3.3,3.3.3-alpha,3.3.4,false +2.2.2-rc.1,2.2.2+metadata,2.2.2-rc.10,false +1.1.1-rc.1,1.1.1-rc.3,1.1.1-rc.2,false \ No newline at end of file diff --git a/tests/models/grammar/files/equal-to-semver.csv b/tests/models/grammar/files/equal-to-semver.csv new file mode 100644 index 00000000..87d8db5a --- /dev/null +++ b/tests/models/grammar/files/equal-to-semver.csv @@ -0,0 +1,7 @@ +version1,version2,equals +1.1.1,1.1.1,true +1.1.1,1.1.1+metadata,false +1.1.1,1.1.1-rc.1,false +88.88.88,88.88.88,true +1.2.3----RC-SNAPSHOT.12.9.1--.12,1.2.3----RC-SNAPSHOT.12.9.1--.12,true +10.2.3-DEV-SNAPSHOT,10.2.3-SNAPSHOT-123,false \ No newline at end of file diff --git a/tests/models/grammar/files/invalid-semantic-versions.csv b/tests/models/grammar/files/invalid-semantic-versions.csv new file mode 100644 index 00000000..7a7f9fbc --- /dev/null +++ b/tests/models/grammar/files/invalid-semantic-versions.csv @@ -0,0 +1,28 @@ +invalid +1 +1.2 +1.alpha.2 ++invalid +-invalid +-invalid+invalid +-invalid.01 +alpha +alpha.beta +alpha.beta.1 +alpha.1 +alpha+beta +alpha_beta +alpha. +alpha.. +beta +-alpha. +1.2 +1.2.3.DEV +1.2-SNAPSHOT +1.2.31.2.3----RC-SNAPSHOT.12.09.1--..12+788 +1.2-RC-SNAPSHOT +-1.0.3-gamma+b7718 ++justmeta +1.1.1+ +1.1.1- +#99999999999999999999999.999999999999999999.99999999999999999----RC-SNAPSHOT.12.09.1--------------------------------..12 \ No newline at end of file diff --git a/tests/models/grammar/files/splits_prereq.json b/tests/models/grammar/files/splits_prereq.json new file mode 100644 index 00000000..5efa7fed --- /dev/null +++ b/tests/models/grammar/files/splits_prereq.json @@ -0,0 +1,293 @@ +{"ff": { + "d": [ + { + "trafficTypeName": "user", + "name": "test_prereq", + "prerequisites": [ + { "n": "feature_segment", "ts": ["off", "def_test"] }, + { "n": "rbs_flag", "ts": ["on"] } + ], + "trafficAllocation": 100, + "trafficAllocationSeed": 1582960494, + "seed": 1842944006, + "status": "ACTIVE", + "killed": false, + "defaultTreatment": "def_treatment", + "changeNumber": 1582741588594, + "algo": 2, + "configurations": {}, + "conditions": [ + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": null + }, + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null, + "unaryNumericMatcherData": null, + "betweenMatcherData": null, + "booleanMatcherData": null, + "dependencyMatcherData": null, + "stringMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ], + "label": "default rule" + } + ] + }, + { + "name":"feature_segment", + "trafficTypeId":"u", + "trafficTypeName":"User", + "trafficAllocation": 100, + "trafficAllocationSeed": 1582960494, + "seed":-1177551240, + "status":"ACTIVE", + "killed":false, + "defaultTreatment":"def_test", + "changeNumber": 1582741588594, + "algo": 2, + "configurations": {}, + "conditions":[ + { + "matcherGroup":{ + "combiner":"AND", + "matchers":[ + { + "matcherType":"IN_SEGMENT", + "negate":false, + "userDefinedSegmentMatcherData":{ + "segmentName":"segment-test" + }, + "whitelistMatcherData":null + } + ] + }, + "partitions":[ + { + "treatment":"on", + "size":100 + }, + { + "treatment":"off", + "size":0 + } + ], + "label": "default label" + } + ] + }, + { + "changeNumber": 10, + "trafficTypeName": "user", + "name": "rbs_flag", + "trafficAllocation": 100, + "trafficAllocationSeed": 1828377380, + "seed": -286617921, + "status": "ACTIVE", + "killed": false, + "defaultTreatment": "off", + "algo": 2, + "conditions": [ + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "sample_rule_based_segment" + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ], + "label": "in rule based segment sample_rule_based_segment" + }, + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "ALL_KEYS", + "negate": false + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ], + "label": "default rule" + } + ], + "configurations": {}, + "sets": [], + "impressionsDisabled": false + }, + { + "trafficTypeName": "user", + "name": "prereq_chain", + "prerequisites": [ + { "n": "test_prereq", "ts": ["on"] } + ], + "trafficAllocation": 100, + "trafficAllocationSeed": -2092979940, + "seed": 105482719, + "status": "ACTIVE", + "killed": false, + "defaultTreatment": "on_default", + "changeNumber": 1585948850109, + "algo": 2, + "configurations": {}, + "conditions": [ + { + "conditionType": "WHITELIST", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": null, + "matcherType": "WHITELIST", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": { + "whitelist": [ + "bilal@split.io" + ] + }, + "unaryNumericMatcherData": null, + "betweenMatcherData": null, + "booleanMatcherData": null, + "dependencyMatcherData": null, + "stringMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on_whitelist", + "size": 100 + } + ], + "label": "whitelisted" + }, + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": null + }, + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null, + "unaryNumericMatcherData": null, + "betweenMatcherData": null, + "booleanMatcherData": null, + "dependencyMatcherData": null, + "stringMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + }, + { + "treatment": "V1", + "size": 0 + } + ], + "label": "default rule" + } + ] + } + ], + "s": -1, + "t": 1585948850109 +}, "rbs":{"d": [ + { + "changeNumber": 5, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": false, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] + }], "s": -1, "t": 1585948850109} +} diff --git a/tests/models/grammar/files/valid-semantic-versions.csv b/tests/models/grammar/files/valid-semantic-versions.csv new file mode 100644 index 00000000..f491e77f --- /dev/null +++ b/tests/models/grammar/files/valid-semantic-versions.csv @@ -0,0 +1,25 @@ +higher,lower +1.1.2,1.1.1 +1.0.0,1.0.0-rc.1 +1.1.0-rc.1,1.0.0-beta.11 +1.0.0-beta.11,1.0.0-beta.2 +1.0.0-beta.2,1.0.0-beta +1.0.0-beta,1.0.0-alpha.beta +1.0.0-alpha.beta,1.0.0-alpha.1 +1.0.0-alpha.1,1.0.0-alpha +2.2.2-rc.2+metadata-lalala,2.2.2-rc.1.2 +1.2.3,0.0.4 +1.1.2+meta,1.1.2-prerelease+meta +1.0.0-beta,1.0.0-alpha +1.0.0-alpha0.valid,1.0.0-alpha.0valid +1.0.0-rc.1+build.1,1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay +10.2.3-DEV-SNAPSHOT,1.2.3-SNAPSHOT-123 +1.1.1-rc2,1.0.0-0A.is.legal +1.2.3----RC-SNAPSHOT.12.9.1--.12+788,1.2.3----R-S.12.9.1--.12+meta +1.2.3----RC-SNAPSHOT.12.9.1--.12.88,1.2.3----RC-SNAPSHOT.12.9.1--.12 +9223372036854775807.9223372036854775807.9223372036854775807,9223372036854775807.9223372036854775807.9223372036854775806 +1.1.1-alpha.beta.rc.build.java.pr.support.10,1.1.1-alpha.beta.rc.build.java.pr.support +1.1.2,1.1.1 +1.2.1,1.1.1 +2.1.1,1.1.1 +1.1.1-rc.1,1.1.1-rc.0 \ No newline at end of file diff --git a/tests/models/grammar/test_matchers.py b/tests/models/grammar/test_matchers.py index f6f1c25a..71922431 100644 --- a/tests/models/grammar/test_matchers.py +++ b/tests/models/grammar/test_matchers.py @@ -6,13 +6,19 @@ import json import os.path import re +import pytest from datetime import datetime from splitio.models.grammar import matchers +from splitio.models.grammar.matchers.prerequisites import PrerequisitesMatcher +from splitio.models import splits +from splitio.models import rule_based_segments +from splitio.models.grammar import condition +from splitio.models.grammar.matchers.utils.utils import Semver from splitio.storage import SegmentStorage -from splitio.engine.evaluator import Evaluator - +from splitio.engine.evaluator import Evaluator, EvaluationContext +from tests.integration import splits_json class MatcherTestsBase(object): """Abstract class to make sure we test all relevant methods.""" @@ -398,26 +404,11 @@ def test_from_raw(self, mocker): def test_matcher_behaviour(self, mocker): """Test if the matcher works properly.""" matcher = matchers.UserDefinedSegmentMatcher(self.raw) - segment_storage = mocker.Mock(spec=SegmentStorage) # Test that if the key if the storage wrapper finds the key in the segment, it matches. - segment_storage.segment_contains.return_value = True - assert matcher.evaluate('some_key', {}, {'segment_storage': segment_storage}) is True - + assert matcher.evaluate('some_key', {}, {'evaluator': None, 'ec': EvaluationContext([],{'some_segment': True}, {})}) is True # Test that if the key if the storage wrapper doesn't find the key in the segment, it fails. - segment_storage.segment_contains.return_value = False - assert matcher.evaluate('some_key', {}, {'segment_storage': segment_storage}) is False - - assert segment_storage.segment_contains.mock_calls == [ - mocker.call('some_segment', 'some_key'), - mocker.call('some_segment', 'some_key') - ] - - assert matcher.evaluate([], {}, {'segment_storage': segment_storage}) is False - assert matcher.evaluate({}, {}, {'segment_storage': segment_storage}) is False - assert matcher.evaluate(123, {}, {'segment_storage': segment_storage}) is False - assert matcher.evaluate(True, {}, {'segment_storage': segment_storage}) is False - assert matcher.evaluate(False, {}, {'segment_storage': segment_storage}) is False + assert matcher.evaluate('some_key', {}, {'evaluator': None, 'ec': EvaluationContext([], {'some_segment': False}, {})}) is False def test_to_json(self): """Test that the object serializes to JSON properly.""" @@ -784,30 +775,35 @@ def test_from_raw(self, mocker): def test_matcher_behaviour(self, mocker): """Test if the matcher works properly.""" - parsed = matchers.DependencyMatcher(self.raw) + cond_raw = self.raw.copy() + cond_raw['dependencyMatcherData']['split'] = 'SPLIT_2' + parsed = matchers.DependencyMatcher(cond_raw) evaluator = mocker.Mock(spec=Evaluator) - evaluator.evaluate_feature.return_value = {'treatment': 'on'} - assert parsed.evaluate('test1', {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is True + cond = condition.from_raw(splits_json["splitChange1_1"]['ff']['d'][0]['conditions'][0]) + split = splits.from_raw(splits_json["splitChange1_1"]['ff']['d'][0]) + + evaluator.eval_with_context.return_value = {'treatment': 'on'} + assert parsed.evaluate('SPLIT_2', {}, {'evaluator': evaluator, 'ec': [{'flags': [split], 'segment_memberships': {}}]}) is True - evaluator.evaluate_feature.return_value = {'treatment': 'off'} - assert parsed.evaluate('test1', {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + evaluator.eval_with_context.return_value = {'treatment': 'off'} + assert parsed.evaluate('SPLIT_2', {}, {'evaluator': evaluator, 'ec': [{'flags': [split], 'segment_memberships': {}}]}) is False - assert evaluator.evaluate_feature.mock_calls == [ - mocker.call('some_split', 'test1', 'buck', {}), - mocker.call('some_split', 'test1', 'buck', {}) + assert evaluator.eval_with_context.mock_calls == [ + mocker.call('SPLIT_2', None, 'SPLIT_2', {}, [{'flags': [split], 'segment_memberships': {}}]), + mocker.call('SPLIT_2', None, 'SPLIT_2', {}, [{'flags': [split], 'segment_memberships': {}}]) ] - assert parsed.evaluate([], {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False - assert parsed.evaluate({}, {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False - assert parsed.evaluate(123, {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False - assert parsed.evaluate(object(), {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + assert parsed.evaluate([], {}, {'evaluator': evaluator, 'ec': [{'flags': [split], 'segment_memberships': {}}]}) is False + assert parsed.evaluate({}, {}, {'evaluator': evaluator, 'ec': [{'flags': [split], 'segment_memberships': {}}]}) is False + assert parsed.evaluate(123, {}, {'evaluator': evaluator, 'ec': [{'flags': [split], 'segment_memberships': {}}]}) is False + assert parsed.evaluate(object(), {}, {'evaluator': evaluator, 'ec': [{'flags': [split], 'segment_memberships': {}}]}) is False def test_to_json(self): """Test that the object serializes to JSON properly.""" as_json = matchers.DependencyMatcher(self.raw).to_json() assert as_json['matcherType'] == 'IN_SPLIT_TREATMENT' - assert as_json['dependencyMatcherData']['split'] == 'some_split' + assert as_json['dependencyMatcherData']['split'] == 'SPLIT_2' assert as_json['dependencyMatcherData']['treatments'] == ['on', 'almost_on'] @@ -884,3 +880,290 @@ def test_to_json(self): as_json = matchers.RegexMatcher(self.raw).to_json() assert as_json['matcherType'] == 'MATCHES_STRING' assert as_json['stringMatcherData'] == "^[a-z][A-Z][0-9]$" + +class EqualToSemverMatcherTests(MatcherTestsBase): + """Semver equalto matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'EQUAL_TO_SEMVER', + 'stringMatcherData': "2.1.8" + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.EqualToSemverMatcher) + assert parsed._semver is not None + assert parsed._semver.version == "2.1.8" + assert isinstance(parsed._semver, Semver) + assert parsed._semver._major == 2 + assert parsed._semver._minor == 1 + assert parsed._semver._patch == 8 + assert parsed._semver._pre_release == [] + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.from_raw(self.raw) + assert not parsed._match("2.1.8+rc") + assert parsed._match("2.1.8") + assert not parsed._match("2.1.5") + assert not parsed._match("2.1.5-rc1") + assert not parsed._match(None) + assert not parsed._match("semver") + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.EqualToSemverMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'EQUAL_TO_SEMVER' + assert as_json['stringMatcherData'] == "2.1.8" + + def test_to_str(self): + """Test that the object serializes to str properly.""" + as_str = matchers.EqualToSemverMatcher(self.raw) + assert str(as_str) == "equal semver 2.1.8" + +class GreaterThanOrEqualToSemverMatcherTests(MatcherTestsBase): + """Semver greater or equalto matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'GREATER_THAN_OR_EQUAL_TO_SEMVER', + 'stringMatcherData': "2.1.8" + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.GreaterThanOrEqualToSemverMatcher) + assert parsed._semver is not None + assert parsed._semver.version == "2.1.8" + assert isinstance(parsed._semver, Semver) + assert parsed._semver._major == 2 + assert parsed._semver._minor == 1 + assert parsed._semver._patch == 8 + assert parsed._semver._pre_release == [] + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.from_raw(self.raw) + assert parsed._match("2.1.8+rc") + assert parsed._match("2.1.8") + assert parsed._match("2.1.11") + assert not parsed._match("2.1.5") + assert not parsed._match("2.1.5-rc1") + assert not parsed._match(None) + assert not parsed._match("semver") + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.GreaterThanOrEqualToSemverMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'GREATER_THAN_OR_EQUAL_TO_SEMVER' + assert as_json['stringMatcherData'] == "2.1.8" + + def test_to_str(self): + """Test that the object serializes to str properly.""" + as_str = matchers.GreaterThanOrEqualToSemverMatcher(self.raw) + assert str(as_str) == "greater than or equal to semver 2.1.8" + +class LessThanOrEqualToSemverMatcherTests(MatcherTestsBase): + """Semver less or equalto matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'LESS_THAN_OR_EQUAL_TO_SEMVER', + 'stringMatcherData': "2.1.8" + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.LessThanOrEqualToSemverMatcher) + assert parsed._semver is not None + assert parsed._semver.version == "2.1.8" + assert isinstance(parsed._semver, Semver) + assert parsed._semver._major == 2 + assert parsed._semver._minor == 1 + assert parsed._semver._patch == 8 + assert parsed._semver._pre_release == [] + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.from_raw(self.raw) + assert parsed._match("2.1.8+rc") + assert parsed._match("2.1.8") + assert not parsed._match("2.1.11") + assert parsed._match("2.1.5") + assert parsed._match("2.1.5-rc1") + assert not parsed._match(None) + assert not parsed._match("semver") + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.LessThanOrEqualToSemverMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'LESS_THAN_OR_EQUAL_TO_SEMVER' + assert as_json['stringMatcherData'] == "2.1.8" + + def test_to_str(self): + """Test that the object serializes to str properly.""" + as_str = matchers.LessThanOrEqualToSemverMatcher(self.raw) + assert str(as_str) == "less than or equal to semver 2.1.8" + +class BetweenSemverMatcherTests(MatcherTestsBase): + """Semver between matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'BETWEEN_SEMVER', + 'betweenStringMatcherData': {"start": "2.1.8", "end": "2.1.11"} + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.BetweenSemverMatcher) + assert isinstance(parsed._semver_start, Semver) + assert isinstance(parsed._semver_end, Semver) + assert parsed._semver_start.version == "2.1.8" + assert parsed._semver_start._major == 2 + assert parsed._semver_start._minor == 1 + assert parsed._semver_start._patch == 8 + assert parsed._semver_start._pre_release == [] + + assert parsed._semver_end.version == "2.1.11" + assert parsed._semver_end._major == 2 + assert parsed._semver_end._minor == 1 + assert parsed._semver_end._patch == 11 + assert parsed._semver_end._pre_release == [] + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.from_raw(self.raw) + assert parsed._match("2.1.8+rc") + assert parsed._match("2.1.9") + assert parsed._match("2.1.11-rc12") + assert not parsed._match("2.1.5") + assert not parsed._match("2.1.12-rc1") + assert not parsed._match(None) + assert not parsed._match("semver") + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.BetweenSemverMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'BETWEEN_SEMVER' + assert as_json['betweenStringMatcherData'] == {"start": "2.1.8", "end": "2.1.11"} + + def test_to_str(self): + """Test that the object serializes to str properly.""" + as_str = matchers.BetweenSemverMatcher(self.raw) + assert str(as_str) == "between semver 2.1.8 and 2.1.11" + +class InListSemverMatcherTests(MatcherTestsBase): + """Semver inlist matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'IN_LIST_SEMVER', + 'whitelistMatcherData': {"whitelist": ["2.1.8", "2.1.11"]} + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.InListSemverMatcher) + assert parsed._data == ["2.1.8", "2.1.11"] + assert [isinstance(item, str) for item in parsed._semver_list] + assert "2.1.8" in parsed._semver_list + assert "2.1.11" in parsed._semver_list + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.from_raw(self.raw) + assert not parsed._match("2.1.8+rc") + assert parsed._match("2.1.8") + assert not parsed._match("2.1.11-rc12") + assert parsed._match("2.1.11") + assert not parsed._match("2.1.7") + assert not parsed._match(None) + assert not parsed._match("semver") + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.InListSemverMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'IN_LIST_SEMVER' + assert as_json['whitelistMatcherData'] == {"whitelist": ["2.1.8", "2.1.11"]} + + def test_to_str(self): + """Test that the object serializes to str properly.""" + as_str = matchers.InListSemverMatcher(self.raw) + assert str(as_str) == "in list semver ['2.1.8', '2.1.11']" + +class RuleBasedMatcherTests(MatcherTestsBase): + """Rule based segment matcher test cases.""" + + raw ={ + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "sample_rule_based_segment" + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.RuleBasedSegmentMatcher) + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.AllKeysMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'IN_RULE_BASED_SEGMENT' + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + rbs_segments = os.path.join(os.path.dirname(__file__), '../../engine/files', 'rule_base_segments3.json') + with open(rbs_segments, 'r') as flo: + data = json.loads(flo.read()) + + rbs = rule_based_segments.from_raw(data["rbs"]["d"][0]) + matcher = matchers.RuleBasedSegmentMatcher(self.raw) + ec ={'ec': EvaluationContext( + {}, + {"segment1": False}, + {"sample_rule_based_segment": rbs} + )} + assert matcher._match(None, context=ec) is False + assert matcher._match('bilal@split.io', context=ec) is False + assert matcher._match('bilal@split.io', {'email': 'bilal@split.io'}, context=ec) is True + +class PrerequisitesMatcherTests(MatcherTestsBase): + """tests for prerequisites matcher.""" + + def test_init(self, mocker): + """Test init.""" + split_load = os.path.join(os.path.dirname(__file__), 'files', 'splits_prereq.json') + with open(split_load, 'r') as flo: + data = json.loads(flo.read()) + + prereq = splits.from_raw_prerequisites(data['ff']['d'][0]['prerequisites']) + parsed = PrerequisitesMatcher(prereq) + assert parsed._prerequisites == prereq + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + split_load = os.path.join(os.path.dirname(__file__), 'files', 'splits_prereq.json') + with open(split_load, 'r') as flo: + data = json.loads(flo.read()) + prereq = splits.from_raw_prerequisites(data['ff']['d'][3]['prerequisites']) + parsed = PrerequisitesMatcher(prereq) + evaluator = mocker.Mock(spec=Evaluator) + + + evaluator.eval_with_context.return_value = {'treatment': 'on'} + assert parsed.match('SPLIT_2', {}, {'evaluator': evaluator, 'ec': [{'flags': ['prereq_chain'], 'segment_memberships': {}}]}) is True + + evaluator.eval_with_context.return_value = {'treatment': 'off'} + assert parsed.match('SPLIT_2', {}, {'evaluator': evaluator, 'ec': [{'flags': ['prereq_chain'], 'segment_memberships': {}}]}) is False \ No newline at end of file diff --git a/tests/models/grammar/test_semver.py b/tests/models/grammar/test_semver.py new file mode 100644 index 00000000..2a2b1b85 --- /dev/null +++ b/tests/models/grammar/test_semver.py @@ -0,0 +1,71 @@ +"""Condition model tests module.""" +import csv +import os + +from splitio.models.grammar.matchers.utils.utils import build_semver_or_none + +valid_versions = os.path.join(os.path.dirname(__file__), 'files', 'valid-semantic-versions.csv') +invalid_versions = os.path.join(os.path.dirname(__file__), 'files', 'invalid-semantic-versions.csv') +equalto_versions = os.path.join(os.path.dirname(__file__), 'files', 'equal-to-semver.csv') +between_versions = os.path.join(os.path.dirname(__file__), 'files', 'between-semver.csv') + +class SemverTests(object): + """Test the semver object model.""" + + def test_valid_versions(self): + with open(valid_versions) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + assert build_semver_or_none(row['higher']) is not None + assert build_semver_or_none(row['lower']) is not None + + def test_invalid_versions(self): + with open(invalid_versions) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + assert build_semver_or_none(row['invalid']) is None + + def test_compare(self): + with open(valid_versions) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + higher = build_semver_or_none(row['higher']) + lower = build_semver_or_none(row['lower']) + assert higher is not None + assert lower is not None + assert higher.compare(lower) == 1 + assert lower.compare(higher) == -1 + + with open(equalto_versions) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + version1 = build_semver_or_none(row['version1']) + version2 = build_semver_or_none(row['version2']) + assert version1 is not None + assert version2 is not None + if row['equals'] == "true": + assert version1.version == version2.version + else: + assert version1.version != version2.version + + with open(between_versions) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + version1 = build_semver_or_none(row['version1']) + version2 = build_semver_or_none(row['version2']) + version3 = build_semver_or_none(row['version3']) + assert version1 is not None + assert version2 is not None + assert version3 is not None + if row['expected'] == "true": + assert version2.compare(version1) >= 0 and version3.compare(version2) >= 0 + else: + assert version2.compare(version1) < 0 or version3.compare(version2) < 0 + + def test_leading_zeros(self): + semver = build_semver_or_none('1.01.2') + assert semver is not None + assert semver.version == '1.1.2' + semver2 = build_semver_or_none('1.01.2-rc.01') + assert semver2 is not None + assert semver2.version == '1.1.2-rc.1' diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py new file mode 100644 index 00000000..aadb6007 --- /dev/null +++ b/tests/models/test_fallback.py @@ -0,0 +1,63 @@ +from splitio.models.fallback_treatment import FallbackTreatment +from splitio.models.fallback_config import FallbackTreatmentsConfiguration, FallbackTreatmentCalculator + +class FallbackTreatmentModelTests(object): + """Fallback treatment model tests.""" + + def test_working(self): + fallback_treatment = FallbackTreatment("on", '{"prop": "val"}') + assert fallback_treatment.config == '{"prop": "val"}' + assert fallback_treatment.treatment == 'on' + + fallback_treatment = FallbackTreatment("off") + assert fallback_treatment.config == None + assert fallback_treatment.treatment == 'off' + +class FallbackTreatmentsConfigModelTests(object): + """Fallback treatment configuration model tests.""" + + def test_working(self): + global_fb = FallbackTreatment("on") + flag_fb = FallbackTreatment("off") + fallback_config = FallbackTreatmentsConfiguration(global_fb, {"flag1": flag_fb}) + assert fallback_config.global_fallback_treatment == global_fb + assert fallback_config.by_flag_fallback_treatment == {"flag1": flag_fb} + + fallback_config.global_fallback_treatment = None + assert fallback_config.global_fallback_treatment == None + + fallback_config.by_flag_fallback_treatment["flag2"] = flag_fb + assert fallback_config.by_flag_fallback_treatment == {"flag1": flag_fb, "flag2": flag_fb} + + fallback_config = FallbackTreatmentsConfiguration("on", {"flag1": "off"}) + assert isinstance(fallback_config.global_fallback_treatment, FallbackTreatment) + assert fallback_config.global_fallback_treatment.treatment == "on" + + assert isinstance(fallback_config.by_flag_fallback_treatment["flag1"], FallbackTreatment) + assert fallback_config.by_flag_fallback_treatment["flag1"].treatment == "off" + + +class FallbackTreatmentCalculatorTests(object): + """Fallback treatment calculator model tests.""" + + def test_working(self): + fallback_config = FallbackTreatmentsConfiguration(FallbackTreatment("on" ,"{}"), None) + fallback_calculator = FallbackTreatmentCalculator(fallback_config) + assert fallback_calculator.fallback_treatments_configuration == fallback_config + assert fallback_calculator._label_prefix == "fallback - " + + fallback_treatment = fallback_calculator.resolve("feature", "not ready") + assert fallback_treatment.treatment == "on" + assert fallback_treatment.label == "fallback - not ready" + assert fallback_treatment.config == "{}" + + fallback_calculator._fallback_treatments_configuration = FallbackTreatmentsConfiguration(FallbackTreatment("on" ,"{}"), {'feature': FallbackTreatment("off" , '{"prop": "val"}')}) + fallback_treatment = fallback_calculator.resolve("feature", "not ready") + assert fallback_treatment.treatment == "off" + assert fallback_treatment.label == "fallback - not ready" + assert fallback_treatment.config == '{"prop": "val"}' + + fallback_treatment = fallback_calculator.resolve("feature2", "not ready") + assert fallback_treatment.treatment == "on" + assert fallback_treatment.label == "fallback - not ready" + assert fallback_treatment.config == "{}" diff --git a/tests/models/test_rule_based_segments.py b/tests/models/test_rule_based_segments.py new file mode 100644 index 00000000..98e35fe8 --- /dev/null +++ b/tests/models/test_rule_based_segments.py @@ -0,0 +1,103 @@ +"""Split model tests module.""" +import copy +from splitio.models import rule_based_segments +from splitio.models import splits +from splitio.models.grammar.condition import Condition +from splitio.models.grammar.matchers.rule_based_segment import RuleBasedSegmentMatcher + +class RuleBasedSegmentModelTests(object): + """Rule based segment model tests.""" + + raw = { + "changeNumber": 123, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": False, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] + } + + def test_from_raw(self): + """Test split model parsing.""" + parsed = rule_based_segments.from_raw(self.raw) + assert isinstance(parsed, rule_based_segments.RuleBasedSegment) + assert parsed.change_number == 123 + assert parsed.name == 'sample_rule_based_segment' + assert parsed.status == splits.Status.ACTIVE + assert len(parsed.conditions) == 1 + assert parsed.excluded.get_excluded_keys() == ["mauro@split.io","gaston@split.io"] + assert parsed.excluded.get_excluded_segments() == [] + conditions = parsed.conditions[0].to_json() + assert conditions['matcherGroup']['matchers'][0] == { + 'betweenMatcherData': None, 'booleanMatcherData': None, 'dependencyMatcherData': None, + 'stringMatcherData': None, 'unaryNumericMatcherData': None, 'userDefinedSegmentMatcherData': None, + "keySelector": { + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": False, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + + def test_incorrect_matcher(self): + """Test incorrect matcher in split model parsing.""" + rbs = copy.deepcopy(self.raw) + rbs['conditions'][0]['matcherGroup']['matchers'][0]['matcherType'] = 'INVALID_MATCHER' + rbs = rule_based_segments.from_raw(rbs) + assert rbs.conditions[0].to_json() == splits._DEFAULT_CONDITIONS_TEMPLATE + + # using multiple conditions + rbs = copy.deepcopy(self.raw) + rbs['conditions'].append(rbs['conditions'][0]) + rbs['conditions'][0]['matcherGroup']['matchers'][0]['matcherType'] = 'INVALID_MATCHER' + parsed = rule_based_segments.from_raw(rbs) + assert parsed.conditions[0].to_json() == splits._DEFAULT_CONDITIONS_TEMPLATE + + def test_get_condition_segment_names(self): + rbs = copy.deepcopy(self.raw) + rbs['conditions'].append( + {"matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": None + } + ] + }, + }) + rbs = rule_based_segments.from_raw(rbs) + + assert rbs.get_condition_segment_names() == {"employees"} \ No newline at end of file diff --git a/tests/models/test_splits.py b/tests/models/test_splits.py index 847448b0..472ecde9 100644 --- a/tests/models/test_splits.py +++ b/tests/models/test_splits.py @@ -1,9 +1,9 @@ """Split model tests module.""" +import copy from splitio.models import splits from splitio.models.grammar.condition import Condition - class SplitTests(object): """Split model tests.""" @@ -11,6 +11,10 @@ class SplitTests(object): 'changeNumber': 123, 'trafficTypeName': 'user', 'name': 'some_name', + 'prerequisites': [ + { 'n': 'flag1', 'ts': ['on','v1'] }, + { 'n': 'flag2', 'ts': ['off'] } + ], 'trafficAllocation': 100, 'trafficAllocationSeed': 123456, 'seed': 321654, @@ -60,6 +64,8 @@ class SplitTests(object): 'configurations': { 'on': '{"color": "blue", "size": 13}' }, + 'sets': ['set1', 'set2'], + 'impressionsDisabled': False } def test_from_raw(self): @@ -79,17 +85,30 @@ def test_from_raw(self): assert len(parsed.conditions) == 2 assert parsed.get_configurations_for('on') == '{"color": "blue", "size": 13}' assert parsed._configurations == {'on': '{"color": "blue", "size": 13}'} - + assert parsed.sets == {'set1', 'set2'} + assert parsed.impressions_disabled == False + assert len(parsed.prerequisites) == 2 + flag1 = False + flag2 = False + for prerequisite in parsed.prerequisites: + if prerequisite.feature_flag_name == 'flag1': + flag1 = True + assert prerequisite.treatments == ['on','v1'] + if prerequisite.feature_flag_name == 'flag2': + flag2 = True + assert prerequisite.treatments == ['off'] + assert flag1 + assert flag2 + def test_get_segment_names(self, mocker): """Test fetching segment names.""" cond1 = mocker.Mock(spec=Condition) cond2 = mocker.Mock(spec=Condition) cond1.get_segment_names.return_value = ['segment1', 'segment2'] cond2.get_segment_names.return_value = ['segment3', 'segment4'] - split1 = splits.Split( 'some_split', 123, False, 'off', 'user', 'ACTIVE', 123, [cond1, cond2]) + split1 = splits.Split( 'some_split', 123, False, 'off', 'user', 'ACTIVE', 123, [cond1, cond2], None) assert split1.get_segment_names() == ['segment%d' % i for i in range(1, 5)] - def test_to_json(self): """Test json serialization.""" as_json = splits.from_raw(self.raw).to_json() @@ -105,6 +124,8 @@ def test_to_json(self): assert as_json['defaultTreatment'] == 'off' assert as_json['algo'] == 2 assert len(as_json['conditions']) == 2 + assert sorted(as_json['sets']) == ['set1', 'set2'] + assert as_json['impressionsDisabled'] is False def test_to_split_view(self): """Test SplitView creation.""" @@ -115,3 +136,19 @@ def test_to_split_view(self): assert as_split_view.killed == self.raw['killed'] assert as_split_view.traffic_type == self.raw['trafficTypeName'] assert set(as_split_view.treatments) == set(['on', 'off']) + assert sorted(as_split_view.sets) == sorted(list(self.raw['sets'])) + assert as_split_view.impressions_disabled == self.raw['impressionsDisabled'] + + def test_incorrect_matcher(self): + """Test incorrect matcher in split model parsing.""" + split = copy.deepcopy(self.raw) + split['conditions'][0]['matcherGroup']['matchers'][0]['matcherType'] = 'INVALID_MATCHER' + parsed = splits.from_raw(split) + assert parsed.conditions[0].to_json() == splits._DEFAULT_CONDITIONS_TEMPLATE + + # using multiple conditions + split = copy.deepcopy(self.raw) + split['conditions'].append(split['conditions'][0]) + split['conditions'][0]['matcherGroup']['matchers'][0]['matcherType'] = 'INVALID_MATCHER' + parsed = splits.from_raw(split) + assert parsed.conditions[0].to_json() == splits._DEFAULT_CONDITIONS_TEMPLATE \ No newline at end of file diff --git a/tests/models/test_telemetry_model.py b/tests/models/test_telemetry_model.py new file mode 100644 index 00000000..7032c359 --- /dev/null +++ b/tests/models/test_telemetry_model.py @@ -0,0 +1,662 @@ +"""Telemetry model test module.""" +import os +import random +import pytest + +from splitio.models.telemetry import StorageType, OperationMode, MethodLatencies, MethodExceptions, \ + HTTPLatencies, HTTPErrors, LastSynchronization, TelemetryCounters, TelemetryConfig, \ + StreamingEvent, StreamingEvents, MethodExceptionsAsync, HTTPLatenciesAsync, HTTPErrorsAsync, LastSynchronizationAsync, \ + TelemetryCountersAsync, TelemetryConfigAsync, StreamingEventsAsync, MethodLatenciesAsync, UpdateFromSSE + +import splitio.models.telemetry as ModelTelemetry + +class TelemetryModelTests(object): + """Telemetry model test cases.""" + + def test_latency_bucket_index(self): + for i in range(50000): + latency = random.randint(10, 9987885) + old_bucket = 0 + result_bucket = 0 + counter = -1 + for j in ModelTelemetry.BUCKETS: + counter += 1 + if old_bucket == 0: + if latency < j: + old_bucket = 0 + break + old_bucket = j + continue + if counter == ModelTelemetry.MAX_LATENCY_BUCKET_COUNT - 1: + result_bucket = 22 + break + if latency > old_bucket and latency <= j: + result_bucket = counter + break + old_bucket = j + print(latency, old_bucket, j) + assert(result_bucket == ModelTelemetry.get_latency_bucket_index(latency)) + + def test_storage_type_and_operation_mode(self, mocker): + assert(StorageType.MEMORY.value == 'memory') + assert(StorageType.REDIS.value == 'redis') + assert(OperationMode.STANDALONE.value == 'standalone') + assert(OperationMode.CONSUMER.value == 'consumer') + + def test_method_latencies(self, mocker): + method_latencies = MethodLatencies() + + method_latencies.pop_all() # should not raise exception + for method in ModelTelemetry.MethodExceptionsAndLatencies: + method_latencies.add_latency(method, 50) + if method.value == 'treatment': + assert(method_latencies._treatment[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments': + assert(method_latencies._treatments[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatment_with_config': + assert(method_latencies._treatment_with_config[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_with_config': + assert(method_latencies._treatments_with_config[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_by_flag_set': + assert(method_latencies._treatments_by_flag_set[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_by_flag_sets': + assert(method_latencies._treatments_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_with_config_by_flag_set': + assert(method_latencies._treatments_with_config_by_flag_set[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_with_config_by_flag_sets': + assert(method_latencies._treatments_with_config_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'track': + assert(method_latencies._track[ModelTelemetry.get_latency_bucket_index(50)] == 1) + + method_latencies.add_latency(method, 50000000) + if method.value == 'treatment': + assert(method_latencies._treatment[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'treatments': + assert(method_latencies._treatments[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'treatment_with_config': + assert(method_latencies._treatment_with_config[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'treatments_with_config': + assert(method_latencies._treatments_with_config[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_by_flag_set': + assert(method_latencies._treatments_by_flag_set[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_by_flag_sets': + assert(method_latencies._treatments_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_with_config_by_flag_set': + assert(method_latencies._treatments_with_config_by_flag_set[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_with_config_by_flag_sets': + assert(method_latencies._treatments_with_config_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'track': + assert(method_latencies._track[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + + method_latencies.pop_all() + assert(method_latencies._track == [0] * 23) + assert(method_latencies._treatment == [0] * 23) + assert(method_latencies._treatments == [0] * 23) + assert(method_latencies._treatment_with_config == [0] * 23) + assert(method_latencies._treatments_with_config == [0] * 23) + assert(method_latencies._treatments_by_flag_set == [0] * 23) + assert(method_latencies._treatments_by_flag_sets == [0] * 23) + assert(method_latencies._treatments_with_config_by_flag_set == [0] * 23) + assert(method_latencies._treatments_with_config_by_flag_sets == [0] * 23) + + method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT, 10) + [method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS, 20) for i in range(2)] + method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG, 50) + method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG, 20) + [method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, 20) for i in range(3)] + [method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, 20) for i in range(4)] + [method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, 20) for i in range(5)] + [method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, 20) for i in range(6)] + method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TRACK, 20) + latencies = method_latencies.pop_all() + assert(latencies == {'methodLatencies': {'treatment': [1] + [0] * 22, + 'treatments': [2] + [0] * 22, + 'treatment_with_config': [1] + [0] * 22, + 'treatments_with_config': [1] + [0] * 22, + 'treatments_by_flag_set': [3] + [0] * 22, + 'treatments_by_flag_sets': [4] + [0] * 22, + 'treatments_with_config_by_flag_set': [5] + [0] * 22, + 'treatments_with_config_by_flag_sets': [6] + [0] * 22, + 'track': [1] + [0] * 22}}) + + def test_http_latencies(self, mocker): + http_latencies = HTTPLatencies() + + http_latencies.pop_all() # should not raise exception + for resource in ModelTelemetry.HTTPExceptionsAndLatencies: + if self._get_http_latency(resource, http_latencies) == None: + continue + http_latencies.add_latency(resource, 50) + assert(self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(50)] == 1) + http_latencies.add_latency(resource, 50000000) + assert(self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + for j in range(10): + latency = random.randint(1001, 4987885) + current_count = self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(latency)] + [http_latencies.add_latency(resource, latency) for i in range(2)] + assert(self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(latency)] == 2 + current_count) + + http_latencies.pop_all() + assert(http_latencies._event == [0] * 23) + assert(http_latencies._impression == [0] * 23) + assert(http_latencies._impression_count == [0] * 23) + assert(http_latencies._segment == [0] * 23) + assert(http_latencies._split == [0] * 23) + assert(http_latencies._telemetry == [0] * 23) + assert(http_latencies._token == [0] * 23) + + http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, 10) + [http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, i) for i in [10, 20]] + http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, 40) + http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, 60) + http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, 90) + http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, 70) + [http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, i) for i in [10, 15]] + latencies = http_latencies.pop_all() + assert(latencies == {'httpLatencies': {'split': [1] + [0] * 22, 'segment': [1] + [0] * 22, 'impression': [2] + [0] * 22, 'impressionCount': [1] + [0] * 22, 'event': [1] + [0] * 22, 'telemetry': [1] + [0] * 22, 'token': [2] + [0] * 22}}) + + def _get_http_latency(self, resource, storage): + if resource == ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT: + return storage._split + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT: + return storage._segment + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION: + return storage._impression + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + return storage._impression_count + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.EVENT: + return storage._event + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY: + return storage._telemetry + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN: + return storage._token + else: + return + + def test_method_exceptions(self, mocker): + method_exception = MethodExceptions() + + exceptions = method_exception.pop_all() # should not raise exception + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT) for i in range(2)] + method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS) + method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG) + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG) for i in range(5)] + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET) for i in range(6)] + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS) for i in range(7)] + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET) for i in range(8)] + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS) for i in range(9)] + [method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TRACK) for i in range(3)] + exceptions = method_exception.pop_all() + + assert(method_exception._treatment == 0) + assert(method_exception._treatments == 0) + assert(method_exception._treatment_with_config == 0) + assert(method_exception._treatments_with_config == 0) + assert(method_exception._treatments_by_flag_set == 0) + assert(method_exception._treatments_by_flag_sets == 0) + assert(method_exception._treatments_with_config_by_flag_set == 0) + assert(method_exception._treatments_with_config_by_flag_sets == 0) + assert(method_exception._track == 0) + assert(exceptions == {'methodExceptions': {'treatment': 2, + 'treatments': 1, + 'treatment_with_config': 1, + 'treatments_with_config': 5, + 'treatments_by_flag_set': 6, + 'treatments_by_flag_sets': 7, + 'treatments_with_config_by_flag_set': 8, + 'treatments_with_config_by_flag_sets': 9, + 'track': 3}}) + + def test_http_errors(self, mocker): + http_error = HTTPErrors() + errors = http_error.pop_all() # should not raise exception + [http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, str(i)) for i in [500, 501, 502]] + [http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, str(i)) for i in [400, 401, 402]] + http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, '502') + [http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, str(i)) for i in [501, 502]] + http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, '501') + http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, '505') + [http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, '502') for i in range(5)] + errors = http_error.pop_all() + assert(errors == {'httpErrors': {'split': {'400': 1, '401': 1, '402': 1}, 'segment': {'500': 1, '501': 1, '502': 1}, + 'impression': {'502': 1}, 'impressionCount': {'501': 1, '502': 1}, + 'event': {'501': 1}, 'telemetry': {'505': 1}, 'token': {'502': 5}}}) + assert(http_error._split == {}) + assert(http_error._segment == {}) + assert(http_error._impression == {}) + assert(http_error._impression_count == {}) + assert(http_error._event == {}) + assert(http_error._telemetry == {}) + + def test_last_synchronization(self, mocker): + last_synchronization = LastSynchronization() + last_synchronization.get_all() # should not raise exception + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, 10) + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, 20) + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, 40) + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, 60) + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, 90) + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, 70) + last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, 15) + assert(last_synchronization.get_all() == {'lastSynchronizations': {'split': 10, 'segment': 40, 'impression': 20, 'impressionCount': 60, 'event': 90, 'telemetry': 70, 'token': 15}}) + + def test_telemetry_counters(self): + telemetry_counter = TelemetryCounters() + assert(telemetry_counter._impressions_queued == 0) + assert(telemetry_counter._impressions_deduped == 0) + assert(telemetry_counter._impressions_dropped == 0) + assert(telemetry_counter._events_dropped == 0) + assert(telemetry_counter._events_queued == 0) + assert(telemetry_counter._auth_rejections == 0) + assert(telemetry_counter._token_refreshes == 0) + assert(telemetry_counter._update_from_sse == {}) + + assert(telemetry_counter.get_session_length() == 0) + telemetry_counter.record_session_length(20) + assert(telemetry_counter.get_session_length() == 20) + + assert(telemetry_counter.pop_auth_rejections() == 0) + [telemetry_counter.record_auth_rejections() for i in range(5)] + auth_rejections = telemetry_counter.pop_auth_rejections() + assert(telemetry_counter._auth_rejections == 0) + assert(auth_rejections == 5) + + assert(telemetry_counter.pop_token_refreshes() == 0) + [telemetry_counter.record_token_refreshes() for i in range(3)] + token_refreshes = telemetry_counter.pop_token_refreshes() + assert(telemetry_counter._token_refreshes == 0) + assert(token_refreshes == 3) + + assert(telemetry_counter.get_counter_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED) == 0) + assert(telemetry_counter.get_counter_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_DEDUPED) == 0) + assert(telemetry_counter.get_counter_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_DROPPED) == 0) + assert(telemetry_counter.get_counter_stats(ModelTelemetry.CounterConstants.EVENTS_QUEUED) == 0) + assert(telemetry_counter.get_counter_stats(ModelTelemetry.CounterConstants.EVENTS_DROPPED) == 0) + telemetry_counter.record_impressions_value(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED, 10) + assert(telemetry_counter._impressions_queued == 10) + telemetry_counter.record_impressions_value(ModelTelemetry.CounterConstants.IMPRESSIONS_DEDUPED, 14) + assert(telemetry_counter._impressions_deduped == 14) + telemetry_counter.record_impressions_value(ModelTelemetry.CounterConstants.IMPRESSIONS_DROPPED, 2) + assert(telemetry_counter._impressions_dropped == 2) + telemetry_counter.record_events_value(ModelTelemetry.CounterConstants.EVENTS_QUEUED, 30) + assert(telemetry_counter._events_queued == 30) + telemetry_counter.record_events_value(ModelTelemetry.CounterConstants.EVENTS_DROPPED, 1) + assert(telemetry_counter._events_dropped == 1) + telemetry_counter.record_update_from_sse(UpdateFromSSE.SPLIT_UPDATE) + assert(telemetry_counter._update_from_sse[UpdateFromSSE.SPLIT_UPDATE.value] == 1) + updates = telemetry_counter.pop_update_from_sse(UpdateFromSSE.SPLIT_UPDATE) + assert(telemetry_counter._update_from_sse[UpdateFromSSE.SPLIT_UPDATE.value] == 0) + assert(updates == 1) + + def test_streaming_event(self, mocker): + streaming_event = StreamingEvent((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + assert(streaming_event.type == ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value) + assert(streaming_event.data == 'split') + assert(streaming_event.time == 1234) + + def test_streaming_events(self, mocker): + streaming_events = StreamingEvents() + events = streaming_events.pop_streaming_events() # should not raise exception + streaming_events.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + streaming_events.record_streaming_event((ModelTelemetry.StreamingEventTypes.STREAMING_STATUS, 'split', 1234)) + events = streaming_events.pop_streaming_events() + assert(streaming_events._streaming_events == []) + assert(events == {'streamingEvents': [{'e': ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value, 'd': 'split', 't': 1234}, + {'e': ModelTelemetry.StreamingEventTypes.STREAMING_STATUS.value, 'd': 'split', 't': 1234}]}) + + def test_telemetry_config(self): + telemetry_config = TelemetryConfig() + stats = telemetry_config.get_stats() # should not raise exception + config = {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG','' + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None, + 'flagSetsFilter': None + } + telemetry_config.record_config(config, {}, 5, 2) + assert(telemetry_config.get_stats() == {'oM': 0, + 'sT': telemetry_config._get_storage_type(config['operationMode'], config['storageType']), + 'sE': config['streamingEnabled'], + 'rR': {'sp': 30, 'se': 30, 'im': 60, 'ev': 60, 'te': 10}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': config['impressionsQueueSize'], + 'eQ': config['eventsQueueSize'], + 'iM': telemetry_config._get_impressions_mode(config['impressionsMode']), + 'iL': True if config['impressionListener'] is not None else False, + 'hp': telemetry_config._check_if_proxy_detected(), + 'tR': 0, + 'nR': 0, + 'bT': 0, + 'aF': 0, + 'rF': 0, + 'fsT': 5, + 'fsI': 2} + ) + + telemetry_config.record_ready_time(10) + assert(telemetry_config._time_until_ready == 10) + + assert(telemetry_config.get_bur_time_outs() == 0) + [telemetry_config.record_bur_time_out() for i in range(2)] + assert(telemetry_config.get_bur_time_outs() == 2) + + assert(telemetry_config.get_non_ready_usage() == 0) + [telemetry_config.record_not_ready_usage() for i in range(5)] + assert(telemetry_config.get_non_ready_usage() == 5) + + os.environ["https_proxy"] = "some_host_ip" + assert(telemetry_config._check_if_proxy_detected() == True) + + del os.environ["https_proxy"] + assert(telemetry_config._check_if_proxy_detected() == False) + + os.environ["HTTPS_proxy"] = "some_host_ip" + assert(telemetry_config._check_if_proxy_detected() == True) + + del os.environ["HTTPS_proxy"] + assert(telemetry_config._check_if_proxy_detected() == False) + +class TelemetryModelAsyncTests(object): + """Telemetry model async test cases.""" + + @pytest.mark.asyncio + async def test_method_latencies(self, mocker): + method_latencies = await MethodLatenciesAsync.create() + + for method in ModelTelemetry.MethodExceptionsAndLatencies: + await method_latencies.add_latency(method, 50) + if method.value == 'treatment': + assert(method_latencies._treatment[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments': + assert(method_latencies._treatments[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatment_with_config': + assert(method_latencies._treatment_with_config[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_with_config': + assert(method_latencies._treatments_with_config[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_by_flag_set': + assert(method_latencies._treatments_by_flag_set[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_by_flag_sets': + assert(method_latencies._treatments_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_with_config_by_flag_set': + assert(method_latencies._treatments_with_config_by_flag_set[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'treatments_with_config_by_flag_sets': + assert(method_latencies._treatments_with_config_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50)] == 1) + elif method.value == 'track': + assert(method_latencies._track[ModelTelemetry.get_latency_bucket_index(50)] == 1) + + await method_latencies.add_latency(method, 50000000) + if method.value == 'treatment': + assert(method_latencies._treatment[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'treatments': + assert(method_latencies._treatments[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'treatment_with_config': + assert(method_latencies._treatment_with_config[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'treatments_with_config': + assert(method_latencies._treatments_with_config[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_by_flag_set': + assert(method_latencies._treatments_by_flag_set[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_by_flag_sets': + assert(method_latencies._treatments_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_with_config_by_flag_set': + assert(method_latencies._treatments_with_config_by_flag_set[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + elif method.value == 'treatments_with_config_by_flag_sets': + assert(method_latencies._treatments_with_config_by_flag_sets[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + if method.value == 'track': + assert(method_latencies._track[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + + await method_latencies.pop_all() + assert(method_latencies._track == [0] * 23) + assert(method_latencies._treatment == [0] * 23) + assert(method_latencies._treatments == [0] * 23) + assert(method_latencies._treatment_with_config == [0] * 23) + assert(method_latencies._treatments_with_config == [0] * 23) + assert(method_latencies._treatments_by_flag_set == [0] * 23) + assert(method_latencies._treatments_by_flag_sets == [0] * 23) + assert(method_latencies._treatments_with_config_by_flag_set == [0] * 23) + assert(method_latencies._treatments_with_config_by_flag_sets == [0] * 23) + + await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT, 10) + [await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS, 20) for i in range(2)] + await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG, 50) + await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG, 20) + [await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, 20) for i in range(3)] + [await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, 20) for i in range(4)] + [await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, 20) for i in range(5)] + [await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, 20) for i in range(6)] + await method_latencies.add_latency(ModelTelemetry.MethodExceptionsAndLatencies.TRACK, 20) + latencies = await method_latencies.pop_all() + assert(latencies == {'methodLatencies': {'treatment': [1] + [0] * 22, + 'treatments': [2] + [0] * 22, + 'treatment_with_config': [1] + [0] * 22, + 'treatments_with_config': [1] + [0] * 22, + 'treatments_by_flag_set': [3] + [0] * 22, + 'treatments_by_flag_sets': [4] + [0] * 22, + 'treatments_with_config_by_flag_set': [5] + [0] * 22, + 'treatments_with_config_by_flag_sets': [6] + [0] * 22, + 'track': [1] + [0] * 22}}) + + @pytest.mark.asyncio + async def test_http_latencies(self, mocker): + http_latencies = await HTTPLatenciesAsync.create() + + for resource in ModelTelemetry.HTTPExceptionsAndLatencies: + if self._get_http_latency(resource, http_latencies) == None: + continue + await http_latencies.add_latency(resource, 50) + assert(self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(50)] == 1) + await http_latencies.add_latency(resource, 50000000) + assert(self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + for j in range(10): + latency = random.randint(1001, 4987885) + current_count = self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(latency)] + [await http_latencies.add_latency(resource, latency) for i in range(2)] + assert(self._get_http_latency(resource, http_latencies)[ModelTelemetry.get_latency_bucket_index(latency)] == 2 + current_count) + + await http_latencies.pop_all() + assert(http_latencies._event == [0] * 23) + assert(http_latencies._impression == [0] * 23) + assert(http_latencies._impression_count == [0] * 23) + assert(http_latencies._segment == [0] * 23) + assert(http_latencies._split == [0] * 23) + assert(http_latencies._telemetry == [0] * 23) + assert(http_latencies._token == [0] * 23) + + await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, 10) + [await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, i) for i in [10, 20]] + await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, 40) + await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, 60) + await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, 90) + await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, 70) + [await http_latencies.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, i) for i in [10, 15]] + latencies = await http_latencies.pop_all() + assert(latencies == {'httpLatencies': {'split': [1] + [0] * 22, 'segment': [1] + [0] * 22, 'impression': [2] + [0] * 22, 'impressionCount': [1] + [0] * 22, 'event': [1] + [0] * 22, 'telemetry': [1] + [0] * 22, 'token': [2] + [0] * 22}}) + + def _get_http_latency(self, resource, storage): + if resource == ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT: + return storage._split + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT: + return storage._segment + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION: + return storage._impression + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + return storage._impression_count + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.EVENT: + return storage._event + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY: + return storage._telemetry + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN: + return storage._token + else: + return + + @pytest.mark.asyncio + async def test_method_exceptions(self, mocker): + method_exception = await MethodExceptionsAsync.create() + + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT) for i in range(2)] + await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS) + await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG) + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG) for i in range(5)] + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET) for i in range(6)] + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS) for i in range(7)] + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET) for i in range(8)] + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS) for i in range(9)] + [await method_exception.add_exception(ModelTelemetry.MethodExceptionsAndLatencies.TRACK) for i in range(3)] + exceptions = await method_exception.pop_all() + + assert(method_exception._treatment == 0) + assert(method_exception._treatments == 0) + assert(method_exception._treatment_with_config == 0) + assert(method_exception._treatments_with_config == 0) + assert(method_exception._treatments_by_flag_set == 0) + assert(method_exception._treatments_by_flag_sets == 0) + assert(method_exception._treatments_with_config_by_flag_set == 0) + assert(method_exception._treatments_with_config_by_flag_sets == 0) + assert(method_exception._track == 0) + assert(exceptions == {'methodExceptions': {'treatment': 2, + 'treatments': 1, + 'treatment_with_config': 1, + 'treatments_with_config': 5, + 'treatments_by_flag_set': 6, + 'treatments_by_flag_sets': 7, + 'treatments_with_config_by_flag_set': 8, + 'treatments_with_config_by_flag_sets': 9, + 'track': 3}}) + + @pytest.mark.asyncio + async def test_http_errors(self, mocker): + http_error = await HTTPErrorsAsync.create() + [await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, str(i)) for i in [500, 501, 502]] + [await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, str(i)) for i in [400, 401, 402]] + await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, '502') + [await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, str(i)) for i in [501, 502]] + await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, '501') + await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, '505') + [await http_error.add_error(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, '502') for i in range(5)] + errors = await http_error.pop_all() + assert(errors == {'httpErrors': {'split': {'400': 1, '401': 1, '402': 1}, 'segment': {'500': 1, '501': 1, '502': 1}, + 'impression': {'502': 1}, 'impressionCount': {'501': 1, '502': 1}, + 'event': {'501': 1}, 'telemetry': {'505': 1}, 'token': {'502': 5}}}) + assert(http_error._split == {}) + assert(http_error._segment == {}) + assert(http_error._impression == {}) + assert(http_error._impression_count == {}) + assert(http_error._event == {}) + assert(http_error._telemetry == {}) + + @pytest.mark.asyncio + async def test_last_synchronization(self, mocker): + last_synchronization = await LastSynchronizationAsync.create() + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, 10) + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, 20) + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, 40) + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, 60) + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, 90) + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, 70) + await last_synchronization.add_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, 15) + assert(await last_synchronization.get_all() == {'lastSynchronizations': {'split': 10, 'segment': 40, 'impression': 20, 'impressionCount': 60, 'event': 90, 'telemetry': 70, 'token': 15}}) + + @pytest.mark.asyncio + async def test_telemetry_counters(self): + telemetry_counter = await TelemetryCountersAsync.create() + assert(telemetry_counter._impressions_queued == 0) + assert(telemetry_counter._impressions_deduped == 0) + assert(telemetry_counter._impressions_dropped == 0) + assert(telemetry_counter._events_dropped == 0) + assert(telemetry_counter._events_queued == 0) + assert(telemetry_counter._auth_rejections == 0) + assert(telemetry_counter._token_refreshes == 0) + assert(telemetry_counter._update_from_sse == {}) + + await telemetry_counter.record_session_length(20) + assert(await telemetry_counter.get_session_length() == 20) + + [await telemetry_counter.record_auth_rejections() for i in range(5)] + auth_rejections = await telemetry_counter.pop_auth_rejections() + assert(telemetry_counter._auth_rejections == 0) + assert(auth_rejections == 5) + + [await telemetry_counter.record_token_refreshes() for i in range(3)] + token_refreshes = await telemetry_counter.pop_token_refreshes() + assert(telemetry_counter._token_refreshes == 0) + assert(token_refreshes == 3) + + await telemetry_counter.record_impressions_value(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED, 10) + assert(telemetry_counter._impressions_queued == 10) + await telemetry_counter.record_impressions_value(ModelTelemetry.CounterConstants.IMPRESSIONS_DEDUPED, 14) + assert(telemetry_counter._impressions_deduped == 14) + await telemetry_counter.record_impressions_value(ModelTelemetry.CounterConstants.IMPRESSIONS_DROPPED, 2) + assert(telemetry_counter._impressions_dropped == 2) + await telemetry_counter.record_events_value(ModelTelemetry.CounterConstants.EVENTS_QUEUED, 30) + assert(telemetry_counter._events_queued == 30) + await telemetry_counter.record_events_value(ModelTelemetry.CounterConstants.EVENTS_DROPPED, 1) + assert(telemetry_counter._events_dropped == 1) + await telemetry_counter.record_update_from_sse(UpdateFromSSE.SPLIT_UPDATE) + assert(telemetry_counter._update_from_sse[UpdateFromSSE.SPLIT_UPDATE.value] == 1) + updates = await telemetry_counter.pop_update_from_sse(UpdateFromSSE.SPLIT_UPDATE) + assert(telemetry_counter._update_from_sse[UpdateFromSSE.SPLIT_UPDATE.value] == 0) + assert(updates == 1) + + @pytest.mark.asyncio + async def test_streaming_events(self, mocker): + streaming_events = await StreamingEventsAsync.create() + await streaming_events.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + await streaming_events.record_streaming_event((ModelTelemetry.StreamingEventTypes.STREAMING_STATUS, 'split', 1234)) + events = await streaming_events.pop_streaming_events() + assert(streaming_events._streaming_events == []) + assert(events == {'streamingEvents': [{'e': ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value, 'd': 'split', 't': 1234}, + {'e': ModelTelemetry.StreamingEventTypes.STREAMING_STATUS.value, 'd': 'split', 't': 1234}]}) + + @pytest.mark.asyncio + async def test_telemetry_config(self): + telemetry_config = await TelemetryConfigAsync.create() + config = {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG','' + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None, + 'flagSetsFilter': None + } + await telemetry_config.record_config(config, {}, 5, 2) + assert(await telemetry_config.get_stats() == {'oM': 0, + 'sT': telemetry_config._get_storage_type(config['operationMode'], config['storageType']), + 'sE': config['streamingEnabled'], + 'rR': {'sp': 30, 'se': 30, 'im': 60, 'ev': 60, 'te': 10}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': config['impressionsQueueSize'], + 'eQ': config['eventsQueueSize'], + 'iM': telemetry_config._get_impressions_mode(config['impressionsMode']), + 'iL': True if config['impressionListener'] is not None else False, + 'hp': telemetry_config._check_if_proxy_detected(), + 'tR': 0, + 'nR': 0, + 'bT': 0, + 'aF': 0, + 'rF': 0, + 'fsT': 5, + 'fsI': 2} + ) + + await telemetry_config.record_ready_time(10) + assert(telemetry_config._time_until_ready == 10) + + [await telemetry_config.record_bur_time_out() for i in range(2)] + assert(await telemetry_config.get_bur_time_outs() == 2) + + [await telemetry_config.record_not_ready_usage() for i in range(5)] + assert(await telemetry_config.get_non_ready_usage() == 5) diff --git a/tests/models/test_token.py b/tests/models/test_token.py index 935de52b..35444f97 100644 --- a/tests/models/test_token.py +++ b/tests/models/test_token.py @@ -11,8 +11,12 @@ class TokenTests(object): def test_from_raw_false(self): """Test token model parsing.""" parsed = token.from_raw(self.raw_false) - assert parsed == None - + assert parsed.push_enabled == False + assert parsed.iat == None + assert parsed.channels == None + assert parsed.exp == None + assert parsed.token == None + raw_empty = { 'pushEnabled': True, 'token': '', @@ -21,7 +25,11 @@ def test_from_raw_false(self): def test_from_raw_empty(self): """Test token model parsing.""" parsed = token.from_raw(self.raw_empty) - assert parsed == None + assert parsed.push_enabled == False + assert parsed.iat == None + assert parsed.channels == None + assert parsed.exp == None + assert parsed.token == None raw_ok = { 'pushEnabled': True, @@ -39,4 +47,3 @@ def test_from_raw(self): assert parsed.channels['NzM2MDI5Mzc0_MTgyNTg1MTgwNg==_splits'] == ['subscribe'] assert parsed.channels['control_pri'] == ['subscribe', 'channel-metadata:publishers'] assert parsed.channels['control_sec'] == ['subscribe', 'channel-metadata:publishers'] - diff --git a/tests/push/test_manager.py b/tests/push/test_manager.py index d4b48bc1..3525baf3 100644 --- a/tests/push/test_manager.py +++ b/tests/push/test_manager.py @@ -2,19 +2,22 @@ #pylint:disable=no-self-use,protected-access from threading import Thread from queue import Queue +import pytest -from splitio.api.auth import APIException - +from splitio.api import APIException from splitio.models.token import Token - from splitio.push.sse import SSEEvent from splitio.push.parser import parse_incoming_event, EventType, ControlType, ControlMessage, \ OccupancyMessage, SplitChangeUpdate, SplitKillUpdate, SegmentChangeUpdate -from splitio.push.processor import MessageProcessor +from splitio.push.processor import MessageProcessor, MessageProcessorAsync from splitio.push.status_tracker import PushStatusTracker -from splitio.push.manager import PushManager, _TOKEN_REFRESH_GRACE_PERIOD -from splitio.push.splitsse import SplitSSEClient +from splitio.push.manager import PushManager, PushManagerAsync, _TOKEN_REFRESH_GRACE_PERIOD +from splitio.push.splitsse import SplitSSEClient, SplitSSEClientAsync from splitio.push.status_tracker import Status +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync +from splitio.models.telemetry import StreamingEventTypes +from splitio.optional.loaders import asyncio from tests.helpers import Any @@ -34,12 +37,14 @@ def test_connection_success(self, mocker): mocker.patch('splitio.push.manager.Timer', new=timer_mock) mocker.patch('splitio.push.manager.SplitSSEClient', new=sse_constructor_mock) feedback_loop = Queue() - manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) def new_start(*args, **kwargs): # pylint: disable=unused-argument """splitsse.start mock.""" - thread = Thread(target=manager._handle_connection_ready) - thread.setDaemon(True) + thread = Thread(target=manager._handle_connection_ready, daemon=True) thread.start() return True @@ -54,6 +59,8 @@ def new_start(*args, **kwargs): # pylint: disable=unused-argument mocker.call().setName('TokenRefresh'), mocker.call().start() ] + assert(telemetry_storage._streaming_events._streaming_events[0]._type == StreamingEventTypes.TOKEN_REFRESH.value) + assert(telemetry_storage._streaming_events._streaming_events[1]._type == StreamingEventTypes.CONNECTION_ESTABLISHED.value) def test_connection_failure(self, mocker): """Test the connection fails to be established.""" @@ -67,12 +74,14 @@ def test_connection_failure(self, mocker): mocker.patch('splitio.push.manager.Timer', new=timer_mock) mocker.patch('splitio.push.manager.SplitSSEClient', new=sse_constructor_mock) feedback_loop = Queue() - manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) def new_start(*args, **kwargs): # pylint: disable=unused-argument """splitsse.start mock.""" - thread = Thread(target=manager._handle_connection_end) - thread.setDaemon(True) + thread = Thread(target=manager._handle_connection_end, daemon=True) thread.start() return False @@ -82,6 +91,28 @@ def new_start(*args, **kwargs): # pylint: disable=unused-argument assert feedback_loop.get() == Status.PUSH_RETRYABLE_ERROR assert timer_mock.mock_calls == [mocker.call(0, Any())] + def test_empty_auth_respnse(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + api_mock = mocker.Mock() + api_mock.authenticate.return_value = Token(False, None, None, None, None) + + sse_mock = mocker.Mock(spec=SplitSSEClient) + sse_constructor_mock = mocker.Mock() + sse_constructor_mock.return_value = sse_mock + timer_mock = mocker.Mock() + mocker.patch('splitio.push.manager.Timer', new=timer_mock) + mocker.patch('splitio.push.manager.SplitSSEClient', new=sse_constructor_mock) + feedback_loop = Queue() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) + manager.start() + assert feedback_loop.get() == Status.PUSH_NONRETRYABLE_ERROR + assert timer_mock.mock_calls == [mocker.call(0, Any())] + assert sse_mock.mock_calls == [] + + def test_push_disabled(self, mocker): """Test the initial status is ok and reset() works as expected.""" api_mock = mocker.Mock() @@ -94,12 +125,16 @@ def test_push_disabled(self, mocker): mocker.patch('splitio.push.manager.Timer', new=timer_mock) mocker.patch('splitio.push.manager.SplitSSEClient', new=sse_constructor_mock) feedback_loop = Queue() - manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) manager.start() assert feedback_loop.get() == Status.PUSH_NONRETRYABLE_ERROR assert timer_mock.mock_calls == [mocker.call(0, Any())] assert sse_mock.mock_calls == [] + def test_auth_apiexception(self, mocker): """Test the initial status is ok and reset() works as expected.""" api_mock = mocker.Mock() @@ -113,7 +148,10 @@ def test_auth_apiexception(self, mocker): mocker.patch('splitio.push.manager.SplitSSEClient', new=sse_constructor_mock) feedback_loop = Queue() - manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock()) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManager(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) manager.start() assert feedback_loop.get() == Status.PUSH_RETRYABLE_ERROR assert timer_mock.mock_calls == [mocker.call(0, Any())] @@ -122,7 +160,7 @@ def test_auth_apiexception(self, mocker): def test_split_change(self, mocker): """Test update-type messages are properly forwarded to the processor.""" sse_event = SSEEvent('1', EventType.MESSAGE, '', '{}') - update_message = SplitChangeUpdate('chan', 123, 456) + update_message = SplitChangeUpdate('chan', 123, 456, None, None, None) parse_event_mock = mocker.Mock(spec=parse_incoming_event) parse_event_mock.return_value = update_message mocker.patch('splitio.push.manager.parse_incoming_event', new=parse_event_mock) @@ -130,11 +168,13 @@ def test_split_change(self, mocker): processor_mock = mocker.Mock(spec=MessageProcessor) mocker.patch('splitio.push.manager.MessageProcessor', new=processor_mock) - manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + telemetry_runtime_producer = mocker.Mock() + synchronizer = mocker.Mock() + manager = PushManager(mocker.Mock(), synchronizer, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer) manager._event_handler(sse_event) assert parse_event_mock.mock_calls == [mocker.call(sse_event)] assert processor_mock.mock_calls == [ - mocker.call(Any()), + mocker.call(synchronizer, telemetry_runtime_producer), mocker.call().handle(update_message) ] @@ -149,11 +189,13 @@ def test_split_kill(self, mocker): processor_mock = mocker.Mock(spec=MessageProcessor) mocker.patch('splitio.push.manager.MessageProcessor', new=processor_mock) - manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + telemetry_runtime_producer = mocker.Mock() + synchronizer = mocker.Mock() + manager = PushManager(mocker.Mock(), synchronizer, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer) manager._event_handler(sse_event) assert parse_event_mock.mock_calls == [mocker.call(sse_event)] assert processor_mock.mock_calls == [ - mocker.call(Any()), + mocker.call(synchronizer, telemetry_runtime_producer), mocker.call().handle(update_message) ] @@ -168,11 +210,13 @@ def test_segment_change(self, mocker): processor_mock = mocker.Mock(spec=MessageProcessor) mocker.patch('splitio.push.manager.MessageProcessor', new=processor_mock) - manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + telemetry_runtime_producer = mocker.Mock() + synchronizer = mocker.Mock() + manager = PushManager(mocker.Mock(), synchronizer, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer) manager._event_handler(sse_event) assert parse_event_mock.mock_calls == [mocker.call(sse_event)] assert processor_mock.mock_calls == [ - mocker.call(Any()), + mocker.call(synchronizer, telemetry_runtime_producer), mocker.call().handle(update_message) ] @@ -187,13 +231,10 @@ def test_control_message(self, mocker): status_tracker_mock = mocker.Mock(spec=PushStatusTracker) mocker.patch('splitio.push.manager.PushStatusTracker', new=status_tracker_mock) - manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) manager._event_handler(sse_event) assert parse_event_mock.mock_calls == [mocker.call(sse_event)] - assert status_tracker_mock.mock_calls == [ - mocker.call(), - mocker.call().handle_control_message(control_message) - ] + assert status_tracker_mock.mock_calls[1] == mocker.call().handle_control_message(control_message) def test_occupancy_message(self, mocker): """Test control mesage is forwarded to status tracker.""" @@ -206,10 +247,247 @@ def test_occupancy_message(self, mocker): status_tracker_mock = mocker.Mock(spec=PushStatusTracker) mocker.patch('splitio.push.manager.PushStatusTracker', new=status_tracker_mock) - manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + manager = PushManager(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) manager._event_handler(sse_event) assert parse_event_mock.mock_calls == [mocker.call(sse_event)] - assert status_tracker_mock.mock_calls == [ - mocker.call(), - mocker.call().handle_occupancy(occupancy_message) + assert status_tracker_mock.mock_calls[1] == mocker.call().handle_occupancy(occupancy_message) + +class PushManagerAsyncTests(object): + """Parser tests.""" + + @pytest.mark.asyncio + async def test_connection_success(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + api_mock = mocker.Mock() + async def authenticate(): + return Token(True, 'abc', {}, 2000000, 1000000) + api_mock.authenticate.side_effect = authenticate + + self.token = None + def timer_mock(token): + print("timer_mock") + self.token = token + return (token.exp - token.iat) - _TOKEN_REFRESH_GRACE_PERIOD + + async def coro(): + t = 0 + try: + while t < 3: + await asyncio.sleep(1) + yield SSEEvent('1', EventType.MESSAGE, '', '{}') + t += 1 + except Exception: + pass + + sse_mock = mocker.Mock(spec=SplitSSEClientAsync) + sse_mock.start.return_value = coro() + async def stop(): + pass + sse_mock.stop = stop + + feedback_loop = asyncio.Queue() + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManagerAsync(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) + manager._get_time_period = timer_mock + manager._sse_client = sse_mock + + async def deferred_shutdown(): + await asyncio.sleep(2) + await manager.stop(True) + + manager.start() + sse_mock.status = SplitSSEClient._Status.IDLE + shutdown_task = asyncio.get_running_loop().create_task(deferred_shutdown()) + + assert await feedback_loop.get() == Status.PUSH_SUBSYSTEM_UP + assert self.token.push_enabled + assert self.token.token == 'abc' + assert self.token.channels == {} + assert self.token.exp == 2000000 + assert self.token.iat == 1000000 + + try: + await shutdown_task + except: + pass + assert not manager._running + assert(telemetry_storage._streaming_events._streaming_events[0]._type == StreamingEventTypes.TOKEN_REFRESH.value) + assert(telemetry_storage._streaming_events._streaming_events[1]._type == StreamingEventTypes.CONNECTION_ESTABLISHED.value) + + @pytest.mark.asyncio + async def test_connection_failure(self, mocker): + """Test the connection fails to be established.""" + api_mock = mocker.Mock() + async def authenticate(): + return Token(True, 'abc', {}, 2000000, 1000000) + api_mock.authenticate.side_effect = authenticate + + sse_mock = mocker.Mock(spec=SplitSSEClientAsync) + feedback_loop = asyncio.Queue() + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManagerAsync(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) + manager._sse_client = sse_mock + + async def coro(): + if False: yield '' # fit a never-called yield directive to force the func to be an async generator + return + sse_mock.start.return_value = coro() + + manager.start() + assert await feedback_loop.get() == Status.PUSH_RETRYABLE_ERROR + + await manager.stop(True) + assert not manager._running + + @pytest.mark.asyncio + async def test_push_disabled(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + api_mock = mocker.Mock() + async def authenticate(): + return Token(False, 'abc', {}, 1, 2) + api_mock.authenticate.side_effect = authenticate + + sse_mock = mocker.Mock(spec=SplitSSEClientAsync) + feedback_loop = asyncio.Queue() + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + + manager = PushManagerAsync(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) + manager._sse_client = sse_mock + + manager.start() + assert await feedback_loop.get() == Status.PUSH_NONRETRYABLE_ERROR + assert sse_mock.mock_calls == [] + + await manager.stop(True) + assert not manager._running + + @pytest.mark.asyncio + async def test_auth_apiexception(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + api_mock = mocker.Mock() + api_mock.authenticate.side_effect = APIException('something') + + sse_mock = mocker.Mock(spec=SplitSSEClientAsync) + + feedback_loop = asyncio.Queue() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManagerAsync(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) + manager._sse_client = sse_mock + manager.start() + assert await feedback_loop.get() == Status.PUSH_RETRYABLE_ERROR + assert sse_mock.mock_calls == [] + + await manager.stop(True) + assert not manager._running + + @pytest.mark.asyncio + async def test_split_change(self, mocker): + """Test update-type messages are properly forwarded to the processor.""" + sse_event = SSEEvent('1', EventType.MESSAGE, '', '{}') + update_message = SplitChangeUpdate('chan', 123, 456, None, None, None) + parse_event_mock = mocker.Mock(spec=parse_incoming_event) + parse_event_mock.return_value = update_message + mocker.patch('splitio.push.manager.parse_incoming_event', new=parse_event_mock) + + processor_mock = mocker.Mock(spec=MessageProcessorAsync) + mocker.patch('splitio.push.manager.MessageProcessorAsync', new=processor_mock) + + telemetry_runtime_producer = mocker.Mock() + synchronizer = mocker.Mock() + manager = PushManagerAsync(mocker.Mock(), synchronizer, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer) + await manager._event_handler(sse_event) + assert parse_event_mock.mock_calls == [mocker.call(sse_event)] + assert processor_mock.mock_calls == [ + mocker.call(synchronizer, telemetry_runtime_producer), + mocker.call().handle(update_message) + ] + + @pytest.mark.asyncio + async def test_split_kill(self, mocker): + """Test update-type messages are properly forwarded to the processor.""" + sse_event = SSEEvent('1', EventType.MESSAGE, '', '{}') + update_message = SplitKillUpdate('chan', 123, 456, 'some_split', 'off') + parse_event_mock = mocker.Mock(spec=parse_incoming_event) + parse_event_mock.return_value = update_message + mocker.patch('splitio.push.manager.parse_incoming_event', new=parse_event_mock) + + processor_mock = mocker.Mock(spec=MessageProcessorAsync) + mocker.patch('splitio.push.manager.MessageProcessorAsync', new=processor_mock) + + telemetry_runtime_producer = mocker.Mock() + synchronizer = mocker.Mock() + manager = PushManagerAsync(mocker.Mock(), synchronizer, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer) + await manager._event_handler(sse_event) + assert parse_event_mock.mock_calls == [mocker.call(sse_event)] + assert processor_mock.mock_calls == [ + mocker.call(synchronizer, telemetry_runtime_producer), + mocker.call().handle(update_message) ] + + await manager.stop(True) + assert not manager._running + + @pytest.mark.asyncio + async def test_segment_change(self, mocker): + """Test update-type messages are properly forwarded to the processor.""" + sse_event = SSEEvent('1', EventType.MESSAGE, '', '{}') + update_message = SegmentChangeUpdate('chan', 123, 456, 'some_segment') + parse_event_mock = mocker.Mock(spec=parse_incoming_event) + parse_event_mock.return_value = update_message + mocker.patch('splitio.push.manager.parse_incoming_event', new=parse_event_mock) + + processor_mock = mocker.Mock(spec=MessageProcessorAsync) + mocker.patch('splitio.push.manager.MessageProcessorAsync', new=processor_mock) + + telemetry_runtime_producer = mocker.Mock() + synchronizer = mocker.Mock() + manager = PushManagerAsync(mocker.Mock(), synchronizer, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer) + await manager._event_handler(sse_event) + assert parse_event_mock.mock_calls == [mocker.call(sse_event)] + assert processor_mock.mock_calls == [ + mocker.call(synchronizer, telemetry_runtime_producer), + mocker.call().handle(update_message) + ] + + @pytest.mark.asyncio + async def test_control_message(self, mocker): + """Test control mesage is forwarded to status tracker.""" + sse_event = SSEEvent('1', EventType.MESSAGE, '', '{}') + control_message = ControlMessage('chan', 123, ControlType.STREAMING_ENABLED) + parse_event_mock = mocker.Mock(spec=parse_incoming_event) + parse_event_mock.return_value = control_message + mocker.patch('splitio.push.manager.parse_incoming_event', new=parse_event_mock) + + status_tracker_mock = mocker.Mock(spec=PushStatusTracker) + mocker.patch('splitio.push.manager.PushStatusTrackerAsync', new=status_tracker_mock) + + manager = PushManagerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + await manager._event_handler(sse_event) + assert parse_event_mock.mock_calls == [mocker.call(sse_event)] + assert status_tracker_mock.mock_calls[1] == mocker.call().handle_control_message(control_message) + + @pytest.mark.asyncio + async def test_occupancy_message(self, mocker): + """Test control mesage is forwarded to status tracker.""" + sse_event = SSEEvent('1', EventType.MESSAGE, '', '{}') + occupancy_message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 123, 2) + parse_event_mock = mocker.Mock(spec=parse_incoming_event) + parse_event_mock.return_value = occupancy_message + mocker.patch('splitio.push.manager.parse_incoming_event', new=parse_event_mock) + + status_tracker_mock = mocker.Mock(spec=PushStatusTracker) + mocker.patch('splitio.push.manager.PushStatusTrackerAsync', new=status_tracker_mock) + + manager = PushManagerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + await manager._event_handler(sse_event) + assert parse_event_mock.mock_calls == [mocker.call(sse_event)] + assert status_tracker_mock.mock_calls[1] == mocker.call().handle_occupancy(occupancy_message) diff --git a/tests/push/test_parser.py b/tests/push/test_parser.py index 0367f84b..faffb3d0 100644 --- a/tests/push/test_parser.py +++ b/tests/push/test_parser.py @@ -55,7 +55,18 @@ def test_event_parsing(self): assert isinstance(parsed0, SplitKillUpdate) assert parsed0.default_treatment == 'some' assert parsed0.change_number == 1591996754396 - assert parsed0.split_name == 'test' + assert parsed0.feature_flag_name == 'test' + + e1 = make_message( + 'NDA5ODc2MTAyNg==_MzAyODY0NDkyOA==_splits', + {'type':'SPLIT_UPDATE','changeNumber':1591996685190, 'pcn': 12, 'c': 2, 'd': 'eJzEUtFu2kAQ/BU0z4d0hw2Be0MFRVGJIx'}, + ) + parsed1 = parse_incoming_event(e1) + assert isinstance(parsed1, SplitChangeUpdate) + assert parsed1.change_number == 1591996685190 + assert parsed1.previous_change_number == 12 + assert parsed1.compression == 2 + assert parsed1.object_definition == 'eJzEUtFu2kAQ/BU0z4d0hw2Be0MFRVGJIx' e1 = make_message( 'NDA5ODc2MTAyNg==_MzAyODY0NDkyOA==_splits', @@ -64,6 +75,9 @@ def test_event_parsing(self): parsed1 = parse_incoming_event(e1) assert isinstance(parsed1, SplitChangeUpdate) assert parsed1.change_number == 1591996685190 + assert parsed1.previous_change_number == None + assert parsed1.compression == None + assert parsed1.object_definition == None e2 = make_message( 'NDA5ODc2MTAyNg==_MzAyODY0NDkyOA==_segments', diff --git a/tests/push/test_processor.py b/tests/push/test_processor.py index aa6cf52f..673a1917 100644 --- a/tests/push/test_processor.py +++ b/tests/push/test_processor.py @@ -1,8 +1,11 @@ """Message processor tests.""" from queue import Queue -from splitio.push.processor import MessageProcessor -from splitio.sync.synchronizer import Synchronizer +import pytest + +from splitio.push.processor import MessageProcessor, MessageProcessorAsync +from splitio.sync.synchronizer import Synchronizer, SynchronizerAsync from splitio.push.parser import SplitChangeUpdate, SegmentChangeUpdate, SplitKillUpdate +from splitio.optional.loaders import asyncio class ProcessorTests(object): @@ -13,8 +16,8 @@ def test_split_change(self, mocker): sync_mock = mocker.Mock(spec=Synchronizer) queue_mock = mocker.Mock(spec=Queue) mocker.patch('splitio.push.processor.Queue', new=queue_mock) - processor = MessageProcessor(sync_mock) - update = SplitChangeUpdate('sarasa', 123, 123) + processor = MessageProcessor(sync_mock, mocker.Mock()) + update = SplitChangeUpdate('sarasa', 123, 123, None, None, None) processor.handle(update) assert queue_mock.mock_calls == [ mocker.call(), # construction of split queue @@ -27,7 +30,7 @@ def test_split_kill(self, mocker): sync_mock = mocker.Mock(spec=Synchronizer) queue_mock = mocker.Mock(spec=Queue) mocker.patch('splitio.push.processor.Queue', new=queue_mock) - processor = MessageProcessor(sync_mock) + processor = MessageProcessor(sync_mock, mocker.Mock()) update = SplitKillUpdate('sarasa', 123, 456, 'some_split', 'off') processor.handle(update) assert queue_mock.mock_calls == [ @@ -44,7 +47,7 @@ def test_segment_change(self, mocker): sync_mock = mocker.Mock(spec=Synchronizer) queue_mock = mocker.Mock(spec=Queue) mocker.patch('splitio.push.processor.Queue', new=queue_mock) - processor = MessageProcessor(sync_mock) + processor = MessageProcessor(sync_mock, mocker.Mock()) update = SegmentChangeUpdate('sarasa', 123, 123, 'some_segment') processor.handle(update) assert queue_mock.mock_calls == [ @@ -56,3 +59,59 @@ def test_segment_change(self, mocker): def test_todo(self): """Fix previous tests so that we validate WHICH queue the update is pushed into.""" assert NotImplementedError("DO THAT") + +class ProcessorAsyncTests(object): + """Message processor test cases.""" + + @pytest.mark.asyncio + async def test_split_change(self, mocker): + """Test split change is properly handled.""" + sync_mock = mocker.Mock(spec=Synchronizer) + self._update = None + async def put_mock(first, event): + self._update = event + + mocker.patch('splitio.push.processor.asyncio.Queue.put', new=put_mock) + processor = MessageProcessorAsync(sync_mock, mocker.Mock()) + update = SplitChangeUpdate('sarasa', 123, 123, None, None, None) + await processor.handle(update) + assert update == self._update + + @pytest.mark.asyncio + async def test_split_kill(self, mocker): + """Test split kill is properly handled.""" + + self._killed_split = None + async def kill_mock(split_name, default_treatment, change_number): + self._killed_split = (split_name, default_treatment, change_number) + + sync_mock = mocker.Mock(spec=SynchronizerAsync) + sync_mock.kill_split = kill_mock + + self._update = None + async def put_mock(first, event): + self._update = event + + mocker.patch('splitio.push.processor.asyncio.Queue.put', new=put_mock) + processor = MessageProcessorAsync(sync_mock, mocker.Mock()) + update = SplitKillUpdate('sarasa', 123, 456, 'some_split', 'off') + await processor.handle(update) + assert update == self._update + assert ('some_split', 'off', 456) == self._killed_split + + @pytest.mark.asyncio + async def test_segment_change(self, mocker): + """Test segment change is properly handled.""" + + sync_mock = mocker.Mock(spec=SynchronizerAsync) + queue_mock = mocker.Mock(spec=asyncio.Queue) + + self._update = None + async def put_mock(first, event): + self._update = event + + mocker.patch('splitio.push.processor.asyncio.Queue.put', new=put_mock) + processor = MessageProcessorAsync(sync_mock, mocker.Mock()) + update = SegmentChangeUpdate('sarasa', 123, 123, 'some_segment') + await processor.handle(update) + assert update == self._update diff --git a/tests/push/test_segment_worker.py b/tests/push/test_segment_worker.py index 9183c2dd..0a99f466 100644 --- a/tests/push/test_segment_worker.py +++ b/tests/push/test_segment_worker.py @@ -4,8 +4,9 @@ import pytest from splitio.api import APIException -from splitio.push.segmentworker import SegmentWorker +from splitio.push.workers import SegmentWorker, SegmentWorkerAsync from splitio.models.notification import SegmentChangeNotification +from splitio.optional.loaders import asyncio change_number_received = None segment_name_received = None @@ -58,3 +59,55 @@ def test_handler(self): segment_worker.stop() assert not segment_worker.is_running() + +class SegmentWorkerAsyncTests(object): + + @pytest.mark.asyncio + async def test_on_error(self): + q = asyncio.Queue() + + def handler_sync(change_number): + raise APIException('some') + + segment_worker = SegmentWorkerAsync(handler_sync, q) + segment_worker.start() + assert segment_worker.is_running() + + await q.put(SegmentChangeNotification('some', 'SEGMENT_UPDATE', 123456789, 'some')) + + with pytest.raises(Exception): + segment_worker._handler() + + assert segment_worker.is_running() + assert(self._worker_running()) + await segment_worker.stop() + await asyncio.sleep(.1) + assert not segment_worker.is_running() + assert(not self._worker_running()) + + def _worker_running(self): + worker_running = False + for task in asyncio.all_tasks(): + if task._coro.cr_code.co_name == '_run' and not task.done(): + worker_running = True + break + return worker_running + + @pytest.mark.asyncio + async def test_handler(self): + q = asyncio.Queue() + segment_worker = SegmentWorkerAsync(handler_sync, q) + global change_number_received + assert not segment_worker.is_running() + segment_worker.start() + assert segment_worker.is_running() + + await q.put(SegmentChangeNotification('some', 'SEGMENT_UPDATE', 123456789, 'some')) + + await asyncio.sleep(.1) + assert change_number_received == 123456789 + assert segment_name_received == 'some' + + await segment_worker.stop() + await asyncio.sleep(.1) + assert(not self._worker_running()) diff --git a/tests/push/test_split_worker.py b/tests/push/test_split_worker.py index 23fa7060..28b5408d 100644 --- a/tests/push/test_split_worker.py +++ b/tests/push/test_split_worker.py @@ -1,34 +1,157 @@ """Split Worker tests.""" import time import queue +import base64 import pytest from splitio.api import APIException -from splitio.push.splitworker import SplitWorker +from splitio.push.workers import SplitWorker, SplitWorkerAsync from splitio.models.notification import SplitChangeNotification +from splitio.optional.loaders import asyncio +from splitio.push.parser import SplitChangeUpdate, RBSChangeUpdate +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemorySplitStorage, InMemorySegmentStorage, \ + InMemoryTelemetryStorageAsync, InMemorySplitStorageAsync, InMemorySegmentStorageAsync change_number_received = None +rbs = { + "changeNumber": 5, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": False, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] + } +def handler_sync(change_number, rbs_change_number): + global change_number_received + global rbs_change_number_received + + change_number_received = change_number + rbs_change_number_received = rbs_change_number + return -def handler_sync(change_number): +async def handler_async(change_number, rbs_change_number): global change_number_received + global rbs_change_number_received change_number_received = change_number + rbs_change_number_received = rbs_change_number return class SplitWorkerTests(object): - def test_on_error(self): + def test_handler(self, mocker): q = queue.Queue() + split_worker = SplitWorker(handler_sync, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + + global change_number_received + global rbs_change_number_received + assert not split_worker.is_running() + split_worker.start() + assert split_worker.is_running() + + def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + def get_rbs_change_number(): + return 2345 + split_worker._rule_based_segment_storage.get_change_number = get_rbs_change_number + + self._feature_flag_added = None + self._feature_flag_deleted = None + def update(feature_flag_add, feature_flag_delete, change_number): + self._feature_flag_added = feature_flag_add + self._feature_flag_deleted = feature_flag_delete + split_worker._feature_flag_storage.update = update + split_worker._feature_flag_storage.config_flag_sets_used = 0 + + self._rbs_added = None + self._rbs_deleted = None + def update(rbs_add, rbs_delete, change_number): + self._rbs_added = rbs_add + self._rbs_deleted = rbs_delete + split_worker._rule_based_segment_storage.update = update + + # should not call the handler + rbs_change_number_received = 0 + rbs1 = str(rbs) + rbs1 = rbs1.replace("'", "\"") + rbs1 = rbs1.replace("False", "false") + encoded = base64.b64encode(bytes(rbs1, "utf-8")) + q.put(RBSChangeUpdate('some', 'RB_SEGMENT_UPDATE', 123456790, 2345, encoded, 0)) + time.sleep(0.1) + assert rbs_change_number_received == 0 + assert self._rbs_added[0].name == "sample_rule_based_segment" + + # should call the handler + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456789, None, None, None)) + time.sleep(0.1) + assert change_number_received == 123456789 + assert rbs_change_number_received == None + # should call the handler + q.put(RBSChangeUpdate('some', 'RB_SEGMENT_UPDATE', 123456789, None, None, None)) + time.sleep(0.1) + assert rbs_change_number_received == 123456789 + assert change_number_received == None + + + # should call the handler + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 12345, "{}", 1)) + time.sleep(0.1) + assert change_number_received == 123456790 + + # should call the handler + change_number_received = 0 + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 12345, "{}", 3)) + time.sleep(0.1) + assert change_number_received == 123456790 + + # should Not call the handler + change_number_received = 0 + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, "eJzEUtFq20AQ/JUwz2c4WZZr3ZupTQh1FKjcQinGrKU95cjpZE6nh9To34ssJ3FNX0sfd3Zm53b2TgietDbF9vXIGdUMha5lDwFTQiGOmTQlchLRPJlEEZeTVJZ6oimWZTpP5WyWQMCNyoOxZPft0ZoA8TZ5aW1TUDCNg4qk/AueM5dQkyiez6IonS6mAu0IzWWSxovFLBZoA4WuhcLy8/bh+xoCL8bagaXJtixQsqbOhq1nCjW7AIVGawgUz+Qqzrr6wB4qmi9m00/JIk7TZCpAtmqgpgJF47SpOn9+UQt16s9YaS71z9NHOYQFha9Pm83Tty0EagrFM/t733RHqIFZH4wb7LDMVh+Ecc4Lv+ZsuQiNH8hXF3hLv39XXNCHbJ+v7x/X2eDmuKLA74sPihVr47jMuRpWfxy1Kwo0GLQjmv1xpBFD3+96gSP5cLVouM7QQaA1vxhK9uKmd853bEZS9jsBSwe2UDDu7mJxd2Mo/muQy81m/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==", 2)) + time.sleep(0.1) + assert change_number_received == 0 + + split_worker.stop() + assert not split_worker.is_running() + + def test_on_error(self, mocker): + q = queue.Queue() def handler_sync(change_number): raise APIException('some') - split_worker = SplitWorker(handler_sync, q) + split_worker = SplitWorker(handler_sync, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) split_worker.start() assert split_worker.is_running() - q.put(SplitChangeNotification('some', 'SPLIT_UPDATE', 123456789)) + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456789, None, None, None)) with pytest.raises(Exception): split_worker._handler() @@ -39,19 +162,388 @@ def handler_sync(change_number): assert not split_worker.is_running() assert not split_worker._worker.is_alive() - def test_handler(self): + def test_compression(self, mocker): q = queue.Queue() - split_worker = SplitWorker(handler_sync, q) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + split_worker = SplitWorker(handler_sync, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer, mocker.Mock()) + global change_number_received + split_worker.start() + def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + self._feature_flag_added = None + self._feature_flag_deleted = None + def update(feature_flag_add, feature_flag_delete, change_number): + self._feature_flag_added = feature_flag_add + self._feature_flag_deleted = feature_flag_delete + split_worker._feature_flag_storage.update = update + split_worker._feature_flag_storage.config_flag_sets_used = 0 + # compression 0 + self._feature_flag_added = None + self._feature_flag_deleted = None + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'eyJ0cmFmZmljVHlwZU5hbWUiOiJ1c2VyIiwiaWQiOiIzM2VhZmE1MC0xYTY1LTExZWQtOTBkZi1mYTMwZDk2OTA0NDUiLCJuYW1lIjoiYmlsYWxfc3BsaXQiLCJ0cmFmZmljQWxsb2NhdGlvbiI6MTAwLCJ0cmFmZmljQWxsb2NhdGlvblNlZWQiOi0xMzY0MTE5MjgyLCJzZWVkIjotNjA1OTM4ODQzLCJzdGF0dXMiOiJBQ1RJVkUiLCJraWxsZWQiOmZhbHNlLCJkZWZhdWx0VHJlYXRtZW50Ijoib2ZmIiwiY2hhbmdlTnVtYmVyIjoxNjg0MzQwOTA4NDc1LCJhbGdvIjoyLCJjb25maWd1cmF0aW9ucyI6e30sImNvbmRpdGlvbnMiOlt7ImNvbmRpdGlvblR5cGUiOiJST0xMT1VUIiwibWF0Y2hlckdyb3VwIjp7ImNvbWJpbmVyIjoiQU5EIiwibWF0Y2hlcnMiOlt7ImtleVNlbGVjdG9yIjp7InRyYWZmaWNUeXBlIjoidXNlciJ9LCJtYXRjaGVyVHlwZSI6IklOX1NFR01FTlQiLCJuZWdhdGUiOmZhbHNlLCJ1c2VyRGVmaW5lZFNlZ21lbnRNYXRjaGVyRGF0YSI6eyJzZWdtZW50TmFtZSI6ImJpbGFsX3NlZ21lbnQifX1dfSwicGFydGl0aW9ucyI6W3sidHJlYXRtZW50Ijoib24iLCJzaXplIjowfSx7InRyZWF0bWVudCI6Im9mZiIsInNpemUiOjEwMH1dLCJsYWJlbCI6ImluIHNlZ21lbnQgYmlsYWxfc2VnbWVudCJ9LHsiY29uZGl0aW9uVHlwZSI6IlJPTExPVVQiLCJtYXRjaGVyR3JvdXAiOnsiY29tYmluZXIiOiJBTkQiLCJtYXRjaGVycyI6W3sia2V5U2VsZWN0b3IiOnsidHJhZmZpY1R5cGUiOiJ1c2VyIn0sIm1hdGNoZXJUeXBlIjoiQUxMX0tFWVMiLCJuZWdhdGUiOmZhbHNlfV19LCJwYXJ0aXRpb25zIjpbeyJ0cmVhdG1lbnQiOiJvbiIsInNpemUiOjB9LHsidHJlYXRtZW50Ijoib2ZmIiwic2l6ZSI6MTAwfV0sImxhYmVsIjoiZGVmYXVsdCBydWxlIn1dfQ==', 0)) + time.sleep(0.1) + assert self._feature_flag_added[0].name == 'bilal_split' + assert telemetry_storage._counters._update_from_sse['sp'] == 1 + + # compression 2 + self._feature_flag_added = None + self._feature_flag_deleted = None + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'eJzEUtFq20AQ/JUwz2c4WZZr3ZupTQh1FKjcQinGrKU95cjpZE6nh9To34ssJ3FNX0sfd3Zm53b2TgietDbF9vXIGdUMha5lDwFTQiGOmTQlchLRPJlEEZeTVJZ6oimWZTpP5WyWQMCNyoOxZPft0ZoA8TZ5aW1TUDCNg4qk/AueM5dQkyiez6IonS6mAu0IzWWSxovFLBZoA4WuhcLy8/bh+xoCL8bagaXJtixQsqbOhq1nCjW7AIVGawgUz+Qqzrr6wB4qmi9m00/JIk7TZCpAtmqgpgJF47SpOn9+UQt16s9YaS71z9NHOYQFha9Pm83Tty0EagrFM/t733RHqIFZH4wb7LDMVh+Ecc4Lv+ZsuQiNH8hXF3hLv39XXNCHbJ+v7x/X2eDmuKLA74sPihVr47jMuRpWfxy1Kwo0GLQjmv1xpBFD3+96gSP5cLVouM7QQaA1vxhK9uKmd853bEZS9jsBSwe2UDDu7mJxd2Mo/muQy81m/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==', 2)) + time.sleep(0.1) + assert self._feature_flag_added[0].name == 'bilal_split' + assert telemetry_storage._counters._update_from_sse['sp'] == 2 + + # compression 1 + self._feature_flag_added = None + self._feature_flag_deleted = None + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'H4sIAAkVZWQC/8WST0+DQBDFv0qzZ0ig/BF6a2xjGismUk2MaZopzOKmy9Isy0EbvrtDwbY2Xo233Tdv5se85cCMBs5FtvrYYwIlsglratTMYiKns+chcAgc24UwsF0Xczt2cm5z8Jw8DmPH9wPyqr5zKyTITb2XwpA4TJ5KWWVgRKXYxHWcX/QUkVi264W+68bjaGyxupdCJ4i9KPI9UgyYpibI9Ha1eJnT/J2QsnNxkDVaLEcOjTQrjWBKVIasFefky95BFZg05Zb2mrhh5I9vgsiL44BAIIuKTeiQVYqLotHHLyLOoT1quRjub4fztQuLxj89LpePzytClGCyd9R3umr21ErOcitUh2PTZHY29HN2+JGixMxUujNfvMB3+u2pY1AXySad3z3Mk46msACDp8W7jhly4uUpFt3qD33vDAx0gLpXkx+P1GusbdcE24M2F4uaywwVEWvxSa1Oa13Vjvn2RXradm0xCVuUVBJqNCBGV0DrX4OcLpeb+/lreh3jH8Uw/JQj3UhkxPgCCurdEnADAAA=', 1)) + time.sleep(0.1) + assert self._feature_flag_added[0].name == 'bilal_split' + assert telemetry_storage._counters._update_from_sse['sp'] == 3 + + # should call delete split + self._feature_flag_added = None + self._feature_flag_deleted = None + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'eyJ0cmFmZmljVHlwZU5hbWUiOiAidXNlciIsICJpZCI6ICIzM2VhZmE1MC0xYTY1LTExZWQtOTBkZi1mYTMwZDk2OTA0NDUiLCAibmFtZSI6ICJiaWxhbF9zcGxpdCIsICJ0cmFmZmljQWxsb2NhdGlvbiI6IDEwMCwgInRyYWZmaWNBbGxvY2F0aW9uU2VlZCI6IC0xMzY0MTE5MjgyLCAic2VlZCI6IC02MDU5Mzg4NDMsICJzdGF0dXMiOiAiQVJDSElWRUQiLCAia2lsbGVkIjogZmFsc2UsICJkZWZhdWx0VHJlYXRtZW50IjogIm9mZiIsICJjaGFuZ2VOdW1iZXIiOiAxNjg0Mjc1ODM5OTUyLCAiYWxnbyI6IDIsICJjb25maWd1cmF0aW9ucyI6IHt9LCAiY29uZGl0aW9ucyI6IFt7ImNvbmRpdGlvblR5cGUiOiAiUk9MTE9VVCIsICJtYXRjaGVyR3JvdXAiOiB7ImNvbWJpbmVyIjogIkFORCIsICJtYXRjaGVycyI6IFt7ImtleVNlbGVjdG9yIjogeyJ0cmFmZmljVHlwZSI6ICJ1c2VyIn0sICJtYXRjaGVyVHlwZSI6ICJJTl9TRUdNRU5UIiwgIm5lZ2F0ZSI6IGZhbHNlLCAidXNlckRlZmluZWRTZWdtZW50TWF0Y2hlckRhdGEiOiB7InNlZ21lbnROYW1lIjogImJpbGFsX3NlZ21lbnQifX1dfSwgInBhcnRpdGlvbnMiOiBbeyJ0cmVhdG1lbnQiOiAib24iLCAic2l6ZSI6IDB9LCB7InRyZWF0bWVudCI6ICJvZmYiLCAic2l6ZSI6IDEwMH1dLCAibGFiZWwiOiAiaW4gc2VnbWVudCBiaWxhbF9zZWdtZW50In0sIHsiY29uZGl0aW9uVHlwZSI6ICJST0xMT1VUIiwgIm1hdGNoZXJHcm91cCI6IHsiY29tYmluZXIiOiAiQU5EIiwgIm1hdGNoZXJzIjogW3sia2V5U2VsZWN0b3IiOiB7InRyYWZmaWNUeXBlIjogInVzZXIifSwgIm1hdGNoZXJUeXBlIjogIkFMTF9LRVlTIiwgIm5lZ2F0ZSI6IGZhbHNlfV19LCAicGFydGl0aW9ucyI6IFt7InRyZWF0bWVudCI6ICJvbiIsICJzaXplIjogMH0sIHsidHJlYXRtZW50IjogIm9mZiIsICJzaXplIjogMTAwfV0sICJsYWJlbCI6ICJkZWZhdWx0IHJ1bGUifV19', 0)) + time.sleep(0.1) + assert self._feature_flag_deleted[0] == 'bilal_split' + assert self._feature_flag_added == [] + + def test_edge_cases(self, mocker): + q = queue.Queue() + split_worker = SplitWorker(handler_sync, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) global change_number_received - assert not split_worker.is_running() split_worker.start() - assert split_worker.is_running() - q.put(SplitChangeNotification('some', 'SPLIT_UPDATE', 123456789)) + def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + self._feature_flag_added = None + self._feature_flag_deleted = None + def update(feature_flag_add, feature_flag_delete, change_number): + self._feature_flag_added = feature_flag_add + self._feature_flag_deleted = feature_flag_delete + split_worker._feature_flag_storage.update = update + split_worker._feature_flag_storage.config_flag_sets_used = 0 + + # should Not call the handler + self._feature_flag_added = None + change_number_received = 0 + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, "/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==", 2)) + time.sleep(0.1) + assert self._feature_flag_added == None + + # should Not call the handler + self._feature_flag = None + change_number_received = 0 + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, "/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==", 4)) + time.sleep(0.1) + assert self._feature_flag_added == None + + # should Not call the handler + self._feature_flag = None + change_number_received = 0 + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, None, 'eJzEUtFq20AQ/JUwz2c4WZZr3ZupTQh1FKjcQinGrKU95cjpZE6nh9To34ssJ3FNX0sfd3Zm53b2TgietDbF9vXIGdUMha5lDwFTQiGOmTQlchLRPJlEEZeTVJZ6oimWZTpP5WyWQMCNyoOxZPft0ZoA8TZ5aW1TUDCNg4qk/AueM5dQkyiez6IonS6mAu0IzWWSxovFLBZoA4WuhcLy8/bh+xoCL8bagaXJtixQsqbOhq1nCjW7AIVGawgUz+Qqzrr6wB4qmi9m00/JIk7TZCpAtmqgpgJF47SpOn9+UQt16s9YaS71z9NHOYQFha9Pm83Tty0EagrFM/t733RHqIFZH4wb7LDMVh+Ecc4Lv+ZsuQiNH8hXF3hLv39XXNCHbJ+v7x/X2eDmuKLA74sPihVr47jMuRpWfxy1Kwo0GLQjmv1xpBFD3+96gSP5cLVouM7QQaA1vxhK9uKmd853bEZS9jsBSwe2UDDu7mJxd2Mo/muQy81m/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==', 2)) + time.sleep(0.1) + assert self._feature_flag_added == None + # should Not call the handler + self._feature_flag = None + change_number_received = 0 + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, None, 1)) time.sleep(0.1) + assert self._feature_flag_added == None + + def test_fetch_segment(self, mocker): + q = queue.Queue() + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + segment_storage = InMemorySegmentStorage(events_queue) + + self.segment_name = None + def segment_handler_sync(segment_name, change_number): + self.segment_name = segment_name + return + split_worker = SplitWorker(handler_sync, segment_handler_sync, q, split_storage, segment_storage, mocker.Mock(), mocker.Mock()) + split_worker.start() + + def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + def check_instant_ff_update(event): + return True + split_worker._check_instant_ff_update = check_instant_ff_update + + q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 1675095324253, 2345, 'eyJjaGFuZ2VOdW1iZXIiOiAxNjc1MDk1MzI0MjUzLCAidHJhZmZpY1R5cGVOYW1lIjogInVzZXIiLCAibmFtZSI6ICJiaWxhbF9zcGxpdCIsICJ0cmFmZmljQWxsb2NhdGlvbiI6IDEwMCwgInRyYWZmaWNBbGxvY2F0aW9uU2VlZCI6IC0xMzY0MTE5MjgyLCAic2VlZCI6IC02MDU5Mzg4NDMsICJzdGF0dXMiOiAiQUNUSVZFIiwgImtpbGxlZCI6IGZhbHNlLCAiZGVmYXVsdFRyZWF0bWVudCI6ICJvZmYiLCAiYWxnbyI6IDIsICJjb25kaXRpb25zIjogW3siY29uZGl0aW9uVHlwZSI6ICJST0xMT1VUIiwgIm1hdGNoZXJHcm91cCI6IHsiY29tYmluZXIiOiAiQU5EIiwgIm1hdGNoZXJzIjogW3sia2V5U2VsZWN0b3IiOiB7InRyYWZmaWNUeXBlIjogInVzZXIiLCAiYXR0cmlidXRlIjogbnVsbH0sICJtYXRjaGVyVHlwZSI6ICJJTl9TRUdNRU5UIiwgIm5lZ2F0ZSI6IGZhbHNlLCAidXNlckRlZmluZWRTZWdtZW50TWF0Y2hlckRhdGEiOiB7InNlZ21lbnROYW1lIjogImJpbGFsX3NlZ21lbnQifSwgIndoaXRlbGlzdE1hdGNoZXJEYXRhIjogbnVsbCwgInVuYXJ5TnVtZXJpY01hdGNoZXJEYXRhIjogbnVsbCwgImJldHdlZW5NYXRjaGVyRGF0YSI6IG51bGwsICJkZXBlbmRlbmN5TWF0Y2hlckRhdGEiOiBudWxsLCAiYm9vbGVhbk1hdGNoZXJEYXRhIjogbnVsbCwgInN0cmluZ01hdGNoZXJEYXRhIjogbnVsbH1dfSwgInBhcnRpdGlvbnMiOiBbeyJ0cmVhdG1lbnQiOiAib24iLCAic2l6ZSI6IDB9LCB7InRyZWF0bWVudCI6ICJvZmYiLCAic2l6ZSI6IDEwMH1dLCAibGFiZWwiOiAiaW4gc2VnbWVudCBiaWxhbF9zZWdtZW50In0sIHsiY29uZGl0aW9uVHlwZSI6ICJST0xMT1VUIiwgIm1hdGNoZXJHcm91cCI6IHsiY29tYmluZXIiOiAiQU5EIiwgIm1hdGNoZXJzIjogW3sia2V5U2VsZWN0b3IiOiB7InRyYWZmaWNUeXBlIjogInVzZXIiLCAiYXR0cmlidXRlIjogbnVsbH0sICJtYXRjaGVyVHlwZSI6ICJBTExfS0VZUyIsICJuZWdhdGUiOiBmYWxzZSwgInVzZXJEZWZpbmVkU2VnbWVudE1hdGNoZXJEYXRhIjogbnVsbCwgIndoaXRlbGlzdE1hdGNoZXJEYXRhIjogbnVsbCwgInVuYXJ5TnVtZXJpY01hdGNoZXJEYXRhIjogbnVsbCwgImJldHdlZW5NYXRjaGVyRGF0YSI6IG51bGwsICJkZXBlbmRlbmN5TWF0Y2hlckRhdGEiOiBudWxsLCAiYm9vbGVhbk1hdGNoZXJEYXRhIjogbnVsbCwgInN0cmluZ01hdGNoZXJEYXRhIjogbnVsbH1dfSwgInBhcnRpdGlvbnMiOiBbeyJ0cmVhdG1lbnQiOiAib24iLCAic2l6ZSI6IDUwfSwgeyJ0cmVhdG1lbnQiOiAib2ZmIiwgInNpemUiOiA1MH1dLCAibGFiZWwiOiAiZGVmYXVsdCBydWxlIn1dLCAiY29uZmlndXJhdGlvbnMiOiB7fX0=', 0)) + time.sleep(0.1) + assert self.segment_name == "bilal_segment" + +class SplitWorkerAsyncTests(object): + + @pytest.mark.asyncio + async def test_on_error(self, mocker): + q = asyncio.Queue() + + def handler_sync(change_number): + raise APIException('some') + + split_worker = SplitWorkerAsync(handler_async, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + split_worker.start() + assert split_worker.is_running() + + await q.put(SplitChangeNotification('some', 'SPLIT_UPDATE', 123456789)) + with pytest.raises(Exception): + split_worker._handler() + + assert split_worker.is_running() + assert(self._worker_running()) + + await split_worker.stop() + await asyncio.sleep(.1) + + assert not split_worker.is_running() + assert(not self._worker_running()) + + def _worker_running(self): + worker_running = False + for task in asyncio.all_tasks(): + if task._coro.cr_code.co_name == '_run' and not task.done(): + worker_running = True + break + return worker_running + + @pytest.mark.asyncio + async def test_handler(self, mocker): + q = asyncio.Queue() + split_worker = SplitWorkerAsync(handler_async, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + + assert not split_worker.is_running() + split_worker.start() + assert split_worker.is_running() + assert(self._worker_running()) + + global change_number_received + global rbs_change_number_received + + # should call the handler + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456789, None, None, None)) + await asyncio.sleep(0.1) assert change_number_received == 123456789 - split_worker.stop() + async def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + async def get_rbs_change_number(): + return 2345 + split_worker._rule_based_segment_storage.get_change_number = get_rbs_change_number + + self.new_change_number = 0 + self._feature_flag_added = None + self._feature_flag_deleted = None + async def update(feature_flag_add, feature_flag_delete, change_number): + self._feature_flag_added = feature_flag_add + self._feature_flag_deleted = feature_flag_delete + self.new_change_number = change_number + split_worker._feature_flag_storage.update = update + split_worker._feature_flag_storage.config_flag_sets_used = 0 + + async def get(segment_name): + return {} + split_worker._segment_storage.get = get + + async def record_update_from_sse(xx): + pass + split_worker._telemetry_runtime_producer.record_update_from_sse = record_update_from_sse + + self._rbs_added = None + self._rbs_deleted = None + async def update_rbs(rbs_add, rbs_delete, change_number): + self._rbs_added = rbs_add + self._rbs_deleted = rbs_delete + split_worker._rule_based_segment_storage.update = update_rbs + + # should not call the handler + rbs_change_number_received = 0 + rbs1 = str(rbs) + rbs1 = rbs1.replace("'", "\"") + rbs1 = rbs1.replace("False", "false") + encoded = base64.b64encode(bytes(rbs1, "utf-8")) + await q.put(RBSChangeUpdate('some', 'RB_SEGMENT_UPDATE', 123456790, 2345, encoded, 0)) + await asyncio.sleep(0.1) + assert rbs_change_number_received == 0 + assert self._rbs_added[0].name == "sample_rule_based_segment" + + # should call the handler + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 12345, "{}", 1)) + await asyncio.sleep(0.1) + assert change_number_received == 123456790 + + # should call the handler + change_number_received = 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 12345, "{}", 3)) + await asyncio.sleep(0.1) + assert change_number_received == 123456790 + + # should Not call the handler + change_number_received = 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, "eJzEUtFq20AQ/JUwz2c4WZZr3ZupTQh1FKjcQinGrKU95cjpZE6nh9To34ssJ3FNX0sfd3Zm53b2TgietDbF9vXIGdUMha5lDwFTQiGOmTQlchLRPJlEEZeTVJZ6oimWZTpP5WyWQMCNyoOxZPft0ZoA8TZ5aW1TUDCNg4qk/AueM5dQkyiez6IonS6mAu0IzWWSxovFLBZoA4WuhcLy8/bh+xoCL8bagaXJtixQsqbOhq1nCjW7AIVGawgUz+Qqzrr6wB4qmi9m00/JIk7TZCpAtmqgpgJF47SpOn9+UQt16s9YaS71z9NHOYQFha9Pm83Tty0EagrFM/t733RHqIFZH4wb7LDMVh+Ecc4Lv+ZsuQiNH8hXF3hLv39XXNCHbJ+v7x/X2eDmuKLA74sPihVr47jMuRpWfxy1Kwo0GLQjmv1xpBFD3+96gSP5cLVouM7QQaA1vxhK9uKmd853bEZS9jsBSwe2UDDu7mJxd2Mo/muQy81m/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==", 2)) + await asyncio.sleep(0.5) + assert change_number_received == 0 + + await split_worker.stop() + await asyncio.sleep(.1) + assert not split_worker.is_running() + assert(not self._worker_running()) + + @pytest.mark.asyncio + async def test_compression(self, mocker): + q = asyncio.Queue() + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + split_worker = SplitWorkerAsync(handler_async, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), telemetry_runtime_producer, mocker.Mock()) + global change_number_received + split_worker.start() + async def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + async def get(segment_name): + return {} + split_worker._segment_storage.get = get + + async def get_split(feature_flag_name): + return {} + split_worker._feature_flag_storage.get = get_split + + self.new_change_number = 0 + self._feature_flag_added = None + self._feature_flag_deleted = None + async def update(feature_flag_add, feature_flag_delete, change_number): + self._feature_flag_added = feature_flag_add + self._feature_flag_deleted = feature_flag_delete + self.new_change_number = change_number + split_worker._feature_flag_storage.update = update + split_worker._feature_flag_storage.config_flag_sets_used = 0 + + async def contains(rbs): + return False + split_worker._rule_based_segment_storage.contains = contains + + # compression 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'eyJ0cmFmZmljVHlwZU5hbWUiOiJ1c2VyIiwiaWQiOiIzM2VhZmE1MC0xYTY1LTExZWQtOTBkZi1mYTMwZDk2OTA0NDUiLCJuYW1lIjoiYmlsYWxfc3BsaXQiLCJ0cmFmZmljQWxsb2NhdGlvbiI6MTAwLCJ0cmFmZmljQWxsb2NhdGlvblNlZWQiOi0xMzY0MTE5MjgyLCJzZWVkIjotNjA1OTM4ODQzLCJzdGF0dXMiOiJBQ1RJVkUiLCJraWxsZWQiOmZhbHNlLCJkZWZhdWx0VHJlYXRtZW50Ijoib2ZmIiwiY2hhbmdlTnVtYmVyIjoxNjg0MzQwOTA4NDc1LCJhbGdvIjoyLCJjb25maWd1cmF0aW9ucyI6e30sImNvbmRpdGlvbnMiOlt7ImNvbmRpdGlvblR5cGUiOiJST0xMT1VUIiwibWF0Y2hlckdyb3VwIjp7ImNvbWJpbmVyIjoiQU5EIiwibWF0Y2hlcnMiOlt7ImtleVNlbGVjdG9yIjp7InRyYWZmaWNUeXBlIjoidXNlciJ9LCJtYXRjaGVyVHlwZSI6IklOX1NFR01FTlQiLCJuZWdhdGUiOmZhbHNlLCJ1c2VyRGVmaW5lZFNlZ21lbnRNYXRjaGVyRGF0YSI6eyJzZWdtZW50TmFtZSI6ImJpbGFsX3NlZ21lbnQifX1dfSwicGFydGl0aW9ucyI6W3sidHJlYXRtZW50Ijoib24iLCJzaXplIjowfSx7InRyZWF0bWVudCI6Im9mZiIsInNpemUiOjEwMH1dLCJsYWJlbCI6ImluIHNlZ21lbnQgYmlsYWxfc2VnbWVudCJ9LHsiY29uZGl0aW9uVHlwZSI6IlJPTExPVVQiLCJtYXRjaGVyR3JvdXAiOnsiY29tYmluZXIiOiJBTkQiLCJtYXRjaGVycyI6W3sia2V5U2VsZWN0b3IiOnsidHJhZmZpY1R5cGUiOiJ1c2VyIn0sIm1hdGNoZXJUeXBlIjoiQUxMX0tFWVMiLCJuZWdhdGUiOmZhbHNlfV19LCJwYXJ0aXRpb25zIjpbeyJ0cmVhdG1lbnQiOiJvbiIsInNpemUiOjB9LHsidHJlYXRtZW50Ijoib2ZmIiwic2l6ZSI6MTAwfV0sImxhYmVsIjoiZGVmYXVsdCBydWxlIn1dfQ==', 0)) + await asyncio.sleep(0.1) + assert self._feature_flag_added[0].name == 'bilal_split' + assert telemetry_storage._counters._update_from_sse['sp'] == 1 + + # compression 2 + self._feature_flag_added = None + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'eJzEUtFq20AQ/JUwz2c4WZZr3ZupTQh1FKjcQinGrKU95cjpZE6nh9To34ssJ3FNX0sfd3Zm53b2TgietDbF9vXIGdUMha5lDwFTQiGOmTQlchLRPJlEEZeTVJZ6oimWZTpP5WyWQMCNyoOxZPft0ZoA8TZ5aW1TUDCNg4qk/AueM5dQkyiez6IonS6mAu0IzWWSxovFLBZoA4WuhcLy8/bh+xoCL8bagaXJtixQsqbOhq1nCjW7AIVGawgUz+Qqzrr6wB4qmi9m00/JIk7TZCpAtmqgpgJF47SpOn9+UQt16s9YaS71z9NHOYQFha9Pm83Tty0EagrFM/t733RHqIFZH4wb7LDMVh+Ecc4Lv+ZsuQiNH8hXF3hLv39XXNCHbJ+v7x/X2eDmuKLA74sPihVr47jMuRpWfxy1Kwo0GLQjmv1xpBFD3+96gSP5cLVouM7QQaA1vxhK9uKmd853bEZS9jsBSwe2UDDu7mJxd2Mo/muQy81m/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==', 2)) + await asyncio.sleep(0.1) + assert self._feature_flag_added[0].name == 'bilal_split' + assert telemetry_storage._counters._update_from_sse['sp'] == 2 + + # compression 1 + self._feature_flag_added = None + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'H4sIAAkVZWQC/8WST0+DQBDFv0qzZ0ig/BF6a2xjGismUk2MaZopzOKmy9Isy0EbvrtDwbY2Xo233Tdv5se85cCMBs5FtvrYYwIlsglratTMYiKns+chcAgc24UwsF0Xczt2cm5z8Jw8DmPH9wPyqr5zKyTITb2XwpA4TJ5KWWVgRKXYxHWcX/QUkVi264W+68bjaGyxupdCJ4i9KPI9UgyYpibI9Ha1eJnT/J2QsnNxkDVaLEcOjTQrjWBKVIasFefky95BFZg05Zb2mrhh5I9vgsiL44BAIIuKTeiQVYqLotHHLyLOoT1quRjub4fztQuLxj89LpePzytClGCyd9R3umr21ErOcitUh2PTZHY29HN2+JGixMxUujNfvMB3+u2pY1AXySad3z3Mk46msACDp8W7jhly4uUpFt3qD33vDAx0gLpXkx+P1GusbdcE24M2F4uaywwVEWvxSa1Oa13Vjvn2RXradm0xCVuUVBJqNCBGV0DrX4OcLpeb+/lreh3jH8Uw/JQj3UhkxPgCCurdEnADAAA=', 1)) + await asyncio.sleep(0.1) + assert self._feature_flag_added[0].name == 'bilal_split' + assert telemetry_storage._counters._update_from_sse['sp'] == 3 + + # should call delete split + self._feature_flag_added = None + self._feature_flag_deleted = None + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456790, 2345, 'eyJ0cmFmZmljVHlwZU5hbWUiOiAidXNlciIsICJpZCI6ICIzM2VhZmE1MC0xYTY1LTExZWQtOTBkZi1mYTMwZDk2OTA0NDUiLCAibmFtZSI6ICJiaWxhbF9zcGxpdCIsICJ0cmFmZmljQWxsb2NhdGlvbiI6IDEwMCwgInRyYWZmaWNBbGxvY2F0aW9uU2VlZCI6IC0xMzY0MTE5MjgyLCAic2VlZCI6IC02MDU5Mzg4NDMsICJzdGF0dXMiOiAiQVJDSElWRUQiLCAia2lsbGVkIjogZmFsc2UsICJkZWZhdWx0VHJlYXRtZW50IjogIm9mZiIsICJjaGFuZ2VOdW1iZXIiOiAxNjg0Mjc1ODM5OTUyLCAiYWxnbyI6IDIsICJjb25maWd1cmF0aW9ucyI6IHt9LCAiY29uZGl0aW9ucyI6IFt7ImNvbmRpdGlvblR5cGUiOiAiUk9MTE9VVCIsICJtYXRjaGVyR3JvdXAiOiB7ImNvbWJpbmVyIjogIkFORCIsICJtYXRjaGVycyI6IFt7ImtleVNlbGVjdG9yIjogeyJ0cmFmZmljVHlwZSI6ICJ1c2VyIn0sICJtYXRjaGVyVHlwZSI6ICJJTl9TRUdNRU5UIiwgIm5lZ2F0ZSI6IGZhbHNlLCAidXNlckRlZmluZWRTZWdtZW50TWF0Y2hlckRhdGEiOiB7InNlZ21lbnROYW1lIjogImJpbGFsX3NlZ21lbnQifX1dfSwgInBhcnRpdGlvbnMiOiBbeyJ0cmVhdG1lbnQiOiAib24iLCAic2l6ZSI6IDB9LCB7InRyZWF0bWVudCI6ICJvZmYiLCAic2l6ZSI6IDEwMH1dLCAibGFiZWwiOiAiaW4gc2VnbWVudCBiaWxhbF9zZWdtZW50In0sIHsiY29uZGl0aW9uVHlwZSI6ICJST0xMT1VUIiwgIm1hdGNoZXJHcm91cCI6IHsiY29tYmluZXIiOiAiQU5EIiwgIm1hdGNoZXJzIjogW3sia2V5U2VsZWN0b3IiOiB7InRyYWZmaWNUeXBlIjogInVzZXIifSwgIm1hdGNoZXJUeXBlIjogIkFMTF9LRVlTIiwgIm5lZ2F0ZSI6IGZhbHNlfV19LCAicGFydGl0aW9ucyI6IFt7InRyZWF0bWVudCI6ICJvbiIsICJzaXplIjogMH0sIHsidHJlYXRtZW50IjogIm9mZiIsICJzaXplIjogMTAwfV0sICJsYWJlbCI6ICJkZWZhdWx0IHJ1bGUifV19', 0)) + await asyncio.sleep(0.1) + assert self._feature_flag_deleted[0] == 'bilal_split' + assert self._feature_flag_added == [] + + await split_worker.stop() + + @pytest.mark.asyncio + async def test_edge_cases(self, mocker): + q = asyncio.Queue() + split_worker = SplitWorkerAsync(handler_async, mocker.Mock(), q, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + global change_number_received + split_worker.start() + + async def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + self._feature_flag_added = None + self._feature_flag_deleted = None + async def update(feature_flag_add, feature_flag_delete, change_number): + self._feature_flag_added = feature_flag_add + self._feature_flag_deleted = feature_flag_delete + self.new_change_number = change_number + split_worker._feature_flag_storage.update = update + split_worker._feature_flag_storage.config_flag_sets_used = 0 + + # should Not call the handler + self._feature_flag_added = None + change_number_received = 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, "/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==", 2)) + await asyncio.sleep(0.1) + assert self._feature_flag_added == None + + + # should Not call the handler + self._feature_flag_added = None + change_number_received = 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, "/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==", 4)) + await asyncio.sleep(0.1) + assert self._feature_flag_added == None + + # should Not call the handler + self._feature_flag_added = None + change_number_received = 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, None, 'eJzEUtFq20AQ/JUwz2c4WZZr3ZupTQh1FKjcQinGrKU95cjpZE6nh9To34ssJ3FNX0sfd3Zm53b2TgietDbF9vXIGdUMha5lDwFTQiGOmTQlchLRPJlEEZeTVJZ6oimWZTpP5WyWQMCNyoOxZPft0ZoA8TZ5aW1TUDCNg4qk/AueM5dQkyiez6IonS6mAu0IzWWSxovFLBZoA4WuhcLy8/bh+xoCL8bagaXJtixQsqbOhq1nCjW7AIVGawgUz+Qqzrr6wB4qmi9m00/JIk7TZCpAtmqgpgJF47SpOn9+UQt16s9YaS71z9NHOYQFha9Pm83Tty0EagrFM/t733RHqIFZH4wb7LDMVh+Ecc4Lv+ZsuQiNH8hXF3hLv39XXNCHbJ+v7x/X2eDmuKLA74sPihVr47jMuRpWfxy1Kwo0GLQjmv1xpBFD3+96gSP5cLVouM7QQaA1vxhK9uKmd853bEZS9jsBSwe2UDDu7mJxd2Mo/muQy81m/2X9I7+N8R/FcPmUd76zjH7X/w4AAP//90glTw==', 2)) + await asyncio.sleep(0.1) + assert self._feature_flag_added == None + + # should Not call the handler + self._feature_flag_added = None + change_number_received = 0 + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 123456, 2345, None, 1)) + await asyncio.sleep(0.1) + assert self._feature_flag_added == None + + await split_worker.stop() + + @pytest.mark.asyncio + async def test_fetch_segment(self, mocker): + q = asyncio.Queue() + internal_events_queue = asyncio.Queue() + split_storage = InMemorySplitStorageAsync(internal_events_queue) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + + self.segment_name = None + async def segment_handler_sync(segment_name, change_number): + self.segment_name = segment_name + return + split_worker = SplitWorkerAsync(handler_async, segment_handler_sync, q, split_storage, segment_storage, mocker.Mock(), mocker.Mock()) + split_worker.start() + + async def get_change_number(): + return 2345 + split_worker._feature_flag_storage.get_change_number = get_change_number + + async def check_instant_ff_update(event): + return True + split_worker._check_instant_ff_update = check_instant_ff_update + + await q.put(SplitChangeUpdate('some', 'SPLIT_UPDATE', 1675095324253, 2345, 'eyJjaGFuZ2VOdW1iZXIiOiAxNjc1MDk1MzI0MjUzLCAidHJhZmZpY1R5cGVOYW1lIjogInVzZXIiLCAibmFtZSI6ICJiaWxhbF9zcGxpdCIsICJ0cmFmZmljQWxsb2NhdGlvbiI6IDEwMCwgInRyYWZmaWNBbGxvY2F0aW9uU2VlZCI6IC0xMzY0MTE5MjgyLCAic2VlZCI6IC02MDU5Mzg4NDMsICJzdGF0dXMiOiAiQUNUSVZFIiwgImtpbGxlZCI6IGZhbHNlLCAiZGVmYXVsdFRyZWF0bWVudCI6ICJvZmYiLCAiYWxnbyI6IDIsICJjb25kaXRpb25zIjogW3siY29uZGl0aW9uVHlwZSI6ICJST0xMT1VUIiwgIm1hdGNoZXJHcm91cCI6IHsiY29tYmluZXIiOiAiQU5EIiwgIm1hdGNoZXJzIjogW3sia2V5U2VsZWN0b3IiOiB7InRyYWZmaWNUeXBlIjogInVzZXIiLCAiYXR0cmlidXRlIjogbnVsbH0sICJtYXRjaGVyVHlwZSI6ICJJTl9TRUdNRU5UIiwgIm5lZ2F0ZSI6IGZhbHNlLCAidXNlckRlZmluZWRTZWdtZW50TWF0Y2hlckRhdGEiOiB7InNlZ21lbnROYW1lIjogImJpbGFsX3NlZ21lbnQifSwgIndoaXRlbGlzdE1hdGNoZXJEYXRhIjogbnVsbCwgInVuYXJ5TnVtZXJpY01hdGNoZXJEYXRhIjogbnVsbCwgImJldHdlZW5NYXRjaGVyRGF0YSI6IG51bGwsICJkZXBlbmRlbmN5TWF0Y2hlckRhdGEiOiBudWxsLCAiYm9vbGVhbk1hdGNoZXJEYXRhIjogbnVsbCwgInN0cmluZ01hdGNoZXJEYXRhIjogbnVsbH1dfSwgInBhcnRpdGlvbnMiOiBbeyJ0cmVhdG1lbnQiOiAib24iLCAic2l6ZSI6IDB9LCB7InRyZWF0bWVudCI6ICJvZmYiLCAic2l6ZSI6IDEwMH1dLCAibGFiZWwiOiAiaW4gc2VnbWVudCBiaWxhbF9zZWdtZW50In0sIHsiY29uZGl0aW9uVHlwZSI6ICJST0xMT1VUIiwgIm1hdGNoZXJHcm91cCI6IHsiY29tYmluZXIiOiAiQU5EIiwgIm1hdGNoZXJzIjogW3sia2V5U2VsZWN0b3IiOiB7InRyYWZmaWNUeXBlIjogInVzZXIiLCAiYXR0cmlidXRlIjogbnVsbH0sICJtYXRjaGVyVHlwZSI6ICJBTExfS0VZUyIsICJuZWdhdGUiOiBmYWxzZSwgInVzZXJEZWZpbmVkU2VnbWVudE1hdGNoZXJEYXRhIjogbnVsbCwgIndoaXRlbGlzdE1hdGNoZXJEYXRhIjogbnVsbCwgInVuYXJ5TnVtZXJpY01hdGNoZXJEYXRhIjogbnVsbCwgImJldHdlZW5NYXRjaGVyRGF0YSI6IG51bGwsICJkZXBlbmRlbmN5TWF0Y2hlckRhdGEiOiBudWxsLCAiYm9vbGVhbk1hdGNoZXJEYXRhIjogbnVsbCwgInN0cmluZ01hdGNoZXJEYXRhIjogbnVsbH1dfSwgInBhcnRpdGlvbnMiOiBbeyJ0cmVhdG1lbnQiOiAib24iLCAic2l6ZSI6IDUwfSwgeyJ0cmVhdG1lbnQiOiAib2ZmIiwgInNpemUiOiA1MH1dLCAibGFiZWwiOiAiZGVmYXVsdCBydWxlIn1dLCAiY29uZmlndXJhdGlvbnMiOiB7fX0=', 0)) + await asyncio.sleep(0.1) + assert self.segment_name == "bilal_segment" + + await split_worker.stop() diff --git a/tests/push/test_splitsse.py b/tests/push/test_splitsse.py index ebb8fa94..c461f9fe 100644 --- a/tests/push/test_splitsse.py +++ b/tests/push/test_splitsse.py @@ -5,16 +5,14 @@ import pytest from splitio.models.token import Token - -from splitio.push.splitsse import SplitSSEClient -from splitio.push.sse import SSEEvent +from splitio.push.splitsse import SplitSSEClient, SplitSSEClientAsync +from splitio.push.sse import SSEEvent, SSE_EVENT_ERROR from tests.helpers.mockserver import SSEMockServer - from splitio.client.util import SdkMetadata +from splitio.optional.loaders import asyncio - -class SSEClientTests(object): +class SSESplitClientTests(object): """SSEClient test cases.""" def test_split_sse_success(self): @@ -124,3 +122,89 @@ def on_disconnect(): assert status['on_connect'] assert status['on_disconnect'] + + +class SSESplitClientAsyncTests(object): + """SSEClientAsync test cases.""" + + @pytest.mark.asyncio + async def test_split_sse_success(self): + """Test correct initialization. Client ends the connection.""" + request_queue = Queue() + server = SSEMockServer(request_queue) + server.start() + + client = SplitSSEClientAsync(SdkMetadata('1.0', 'some', '1.2.3.4'), + 'abcd', base_url='http://localhost:' + str(server.port())) + + token = Token(True, 'some', {'chan1': ['subscribe'], 'chan2': ['subscribe', 'channel-metadata:publishers']}, + 1, 2) + + events_source = client.start(token) + server.publish({'id': '1'}) # send a non-error event early to unblock start + server.publish({'id': '1', 'data': 'a', 'retry': '1', 'event': 'message'}) + server.publish({'id': '2', 'data': 'a', 'retry': '1', 'event': 'message'}) + + first_event = await events_source.__anext__() + assert first_event.event != SSE_EVENT_ERROR + + + event2 = await events_source.__anext__() + event3 = await events_source.__anext__() + + # Since generators are meant to be iterated, we need to consume them all until StopIteration occurs + # to do this, connection must be closed in another coroutine, while the current one is still consuming events. + shutdown_task = asyncio.get_running_loop().create_task(client.stop()) + with pytest.raises(StopAsyncIteration): await events_source.__anext__() + await shutdown_task + + + request = request_queue.get(1) + assert request.path == '/event-stream?v=1.1&accessToken=some&channels=chan1,%5B?occupancy=metrics.publishers%5Dchan2' + assert request.headers['accept'] == 'text/event-stream' + assert request.headers['SplitSDKVersion'] == '1.0' + assert request.headers['SplitSDKMachineIP'] == '1.2.3.4' + assert request.headers['SplitSDKMachineName'] == 'some' + assert request.headers['SplitSDKClientKey'] == 'abcd' + + assert event2 == SSEEvent('1', 'message', '1', 'a') + assert event3 == SSEEvent('2', 'message', '1', 'a') + + server.publish(SSEMockServer.VIOLENT_REQUEST_END) + server.stop() + await asyncio.sleep(1) + + assert client.status == SplitSSEClient._Status.IDLE + + + @pytest.mark.asyncio + async def test_split_sse_error(self): + """Test correct initialization. Client ends the connection.""" + request_queue = Queue() + server = SSEMockServer(request_queue) + server.start() + + client = SplitSSEClientAsync(SdkMetadata('1.0', 'some', '1.2.3.4'), + 'abcd', base_url='http://localhost:' + str(server.port())) + + token = Token(True, 'some', {'chan1': ['subscribe'], 'chan2': ['subscribe', 'channel-metadata:publishers']}, + 1, 2) + + events_source = client.start(token) + server.publish({'event': 'error'}) # send an error event early to unblock start + + + with pytest.raises(StopAsyncIteration): await events_source.__anext__() + + assert client.status == SplitSSEClient._Status.IDLE + + request = request_queue.get(1) + assert request.path == '/event-stream?v=1.1&accessToken=some&channels=chan1,%5B?occupancy=metrics.publishers%5Dchan2' + assert request.headers['accept'] == 'text/event-stream' + assert request.headers['SplitSDKVersion'] == '1.0' + assert request.headers['SplitSDKMachineIP'] == '1.2.3.4' + assert request.headers['SplitSDKMachineName'] == 'some' + assert request.headers['SplitSDKClientKey'] == 'abcd' + + server.publish(SSEMockServer.VIOLENT_REQUEST_END) + server.stop() diff --git a/tests/push/test_sse.py b/tests/push/test_sse.py index 8bba1714..1e0e2e48 100644 --- a/tests/push/test_sse.py +++ b/tests/push/test_sse.py @@ -3,9 +3,11 @@ import time import threading import pytest -from splitio.push.sse import SSEClient, SSEEvent -from tests.helpers.mockserver import SSEMockServer +from contextlib import suppress +from splitio.push.sse import SSEClient, SSEEvent, SSEClientAsync +from splitio.optional.loaders import asyncio +from tests.helpers.mockserver import SSEMockServer class SSEClientTests(object): """SSEClient test cases.""" @@ -26,7 +28,6 @@ def runner(): """SSE client runner thread.""" assert client.start('http://127.0.0.1:' + str(server.port())) client_task = threading.Thread(target=runner) - client_task.setDaemon(True) client_task.setName('client') client_task.start() with pytest.raises(RuntimeError): @@ -65,9 +66,8 @@ def callback(event): def runner(): """SSE client runner thread.""" - assert client.start('http://127.0.0.1:' + str(server.port())) + assert not client.start('http://127.0.0.1:' + str(server.port())) client_task = threading.Thread(target=runner) - client_task.setDaemon(True) client_task.setName('client') client_task.start() @@ -93,7 +93,7 @@ def test_sse_server_disconnects_abruptly(self): """Test correct initialization. Server ends connection.""" server = SSEMockServer() server.start() - + events = [] def callback(event): """Callback.""" @@ -103,9 +103,8 @@ def callback(event): def runner(): """SSE client runner thread.""" - assert client.start('http://127.0.0.1:' + str(server.port())) - client_task = threading.Thread(target=runner) - client_task.setDaemon(True) + assert not client.start('http://127.0.0.1:' + str(server.port())) + client_task = threading.Thread(target=runner, daemon=True) client_task.setName('client') client_task.start() @@ -126,3 +125,105 @@ def runner(): ] assert client._conn is None + +class SSEClientAsyncTests(object): + """SSEClient test cases.""" + + @pytest.mark.asyncio + async def test_sse_client_disconnects(self): + """Test correct initialization. Client ends the connection.""" + server = SSEMockServer() + server.start() + client = SSEClientAsync() + sse_events_loop = client.start(f"http://127.0.0.1:{str(server.port())}?token=abc123$%^&(") + + server.publish({'id': '1'}) + server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) + server.publish({'id': '3', 'event': 'message', 'data': 'def'}) + server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + + event1 = await sse_events_loop.__anext__() + event2 = await sse_events_loop.__anext__() + event3 = await sse_events_loop.__anext__() + event4 = await sse_events_loop.__anext__() + + # Since generators are meant to be iterated, we need to consume them all until StopIteration occurs + # to do this, connection must be closed in another coroutine, while the current one is still consuming events. + shutdown_task = asyncio.get_running_loop().create_task(client.shutdown()) + with pytest.raises(StopAsyncIteration): await sse_events_loop.__anext__() + await shutdown_task + + assert event1 == SSEEvent('1', None, None, None) + assert event2 == SSEEvent('2', 'message', None, 'abc') + assert event3 == SSEEvent('3', 'message', None, 'def') + assert event4 == SSEEvent('4', 'message', None, 'ghi') + assert client._response == None + + server.publish(server.GRACEFUL_REQUEST_END) + server.stop() + + @pytest.mark.asyncio + async def test_sse_server_disconnects(self): + """Test correct initialization. Server ends connection.""" + server = SSEMockServer() + server.start() + client = SSEClientAsync() + sse_events_loop = client.start('http://127.0.0.1:' + str(server.port())) + + server.publish({'id': '1'}) + server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) + server.publish({'id': '3', 'event': 'message', 'data': 'def'}) + server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + + event1 = await sse_events_loop.__anext__() + event2 = await sse_events_loop.__anext__() + event3 = await sse_events_loop.__anext__() + event4 = await sse_events_loop.__anext__() + + server.publish(server.GRACEFUL_REQUEST_END) + + # after the connection ends, any subsequent read sohould fail and iteration should stop + with pytest.raises(StopAsyncIteration): await sse_events_loop.__anext__() + + assert event1 == SSEEvent('1', None, None, None) + assert event2 == SSEEvent('2', 'message', None, 'abc') + assert event3 == SSEEvent('3', 'message', None, 'def') + assert event4 == SSEEvent('4', 'message', None, 'ghi') + assert client._response == None + + await client._done.wait() # to ensure `start()` has finished + assert client._response is None + +# server.stop() + + + @pytest.mark.asyncio + async def test_sse_server_disconnects_abruptly(self): + """Test correct initialization. Server ends connection.""" + server = SSEMockServer() + server.start() + client = SSEClientAsync() + sse_events_loop = client.start('http://127.0.0.1:' + str(server.port())) + + server.publish({'id': '1'}) + server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) + server.publish({'id': '3', 'event': 'message', 'data': 'def'}) + server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + + event1 = await sse_events_loop.__anext__() + event2 = await sse_events_loop.__anext__() + event3 = await sse_events_loop.__anext__() + event4 = await sse_events_loop.__anext__() + + server.publish(server.VIOLENT_REQUEST_END) + with pytest.raises(StopAsyncIteration): await sse_events_loop.__anext__() + + server.stop() + + assert event1 == SSEEvent('1', None, None, None) + assert event2 == SSEEvent('2', 'message', None, 'abc') + assert event3 == SSEEvent('3', 'message', None, 'def') + assert event4 == SSEEvent('4', 'message', None, 'ghi') + + await client._done.wait() # to ensure `start()` has finished + assert client._response is None diff --git a/tests/push/test_status_tracker.py b/tests/push/test_status_tracker.py index abe8da9e..b77bd483 100644 --- a/tests/push/test_status_tracker.py +++ b/tests/push/test_status_tracker.py @@ -1,15 +1,23 @@ """SSE Status tracker unit tests.""" #pylint:disable=protected-access,no-self-use,line-too-long -from splitio.push.status_tracker import PushStatusTracker, Status +import pytest + +from splitio.push.status_tracker import PushStatusTracker, Status, PushStatusTrackerAsync from splitio.push.parser import ControlType, AblyError, OccupancyMessage, ControlMessage +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync +from splitio.models.telemetry import StreamingEventTypes, SSEStreamingStatus, SSEConnectionError class StatusTrackerTests(object): """Parser tests.""" - def test_initial_status_and_reset(self): + def test_initial_status_and_reset(self, mocker): """Test the initial status is ok and reset() works as expected.""" - tracker = PushStatusTracker() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTracker(telemetry_runtime_producer) assert tracker._occupancy_ok() assert tracker._last_control_message == ControlType.STREAMING_ENABLED assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP @@ -25,13 +33,18 @@ def test_initial_status_and_reset(self): assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP assert not tracker._shutdown_expected - def test_handling_occupancy(self): + def test_handling_occupancy(self, mocker): """Test handling occupancy works properly.""" - tracker = PushStatusTracker() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTracker(telemetry_runtime_producer) assert tracker._occupancy_ok() message = OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 123, 0) assert tracker.handle_occupancy(message) is None + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.OCCUPANCY_SEC.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == len(tracker._publishers)) # old message message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 122, 0) @@ -39,16 +52,25 @@ def test_handling_occupancy(self): message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 124, 0) assert tracker.handle_occupancy(message) is Status.PUSH_SUBSYSTEM_DOWN + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.STREAMING_STATUS.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEStreamingStatus.PAUSED.value) message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 125, 1) assert tracker.handle_occupancy(message) is Status.PUSH_SUBSYSTEM_UP + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.STREAMING_STATUS.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEStreamingStatus.ENABLED.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._type == StreamingEventTypes.OCCUPANCY_PRI.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._data == len(tracker._publishers)) message = OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 125, 2) assert tracker.handle_occupancy(message) is None - def test_handling_control(self): + def test_handling_control(self, mocker): """Test handling incoming control messages.""" - tracker = PushStatusTracker() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTracker(telemetry_runtime_producer) assert tracker._last_control_message == ControlType.STREAMING_ENABLED assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP @@ -69,7 +91,7 @@ def test_handling_control(self): assert tracker.handle_control_message(message) is Status.PUSH_NONRETRYABLE_ERROR # test that disabling works as well with streaming paused - tracker = PushStatusTracker() + tracker = PushStatusTracker(mocker.Mock()) assert tracker._last_control_message == ControlType.STREAMING_ENABLED assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP @@ -78,10 +100,13 @@ def test_handling_control(self): message = ControlMessage('control_pri', 126, ControlType.STREAMING_DISABLED) assert tracker.handle_control_message(message) is Status.PUSH_NONRETRYABLE_ERROR + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.STREAMING_STATUS.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEStreamingStatus.DISABLED.value) + - def test_control_occupancy_overlap(self): + def test_control_occupancy_overlap(self, mocker): """Test control and occupancy messages together.""" - tracker = PushStatusTracker() + tracker = PushStatusTracker(mocker.Mock()) assert tracker._last_control_message == ControlType.STREAMING_ENABLED assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP @@ -100,9 +125,12 @@ def test_control_occupancy_overlap(self): message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 126, 1) assert tracker.handle_occupancy(message) is Status.PUSH_SUBSYSTEM_UP - def test_ably_error(self): + def test_ably_error(self, mocker): """Test the status tracker reacts appropriately to an ably error.""" - tracker = PushStatusTracker() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTracker(telemetry_runtime_producer) assert tracker._last_control_message == ControlType.STREAMING_ENABLED assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP @@ -127,10 +155,16 @@ def test_ably_error(self): tracker.reset() message = AblyError(40139, 100, 'some message', 'http://somewhere') assert tracker.handle_ably_error(message) is Status.PUSH_NONRETRYABLE_ERROR + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.ABLY_ERROR.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == 40139) - def test_disconnect_expected(self): + + def test_disconnect_expected(self, mocker): """Test that no error is propagated when a disconnect is expected.""" - tracker = PushStatusTracker() + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTracker(telemetry_runtime_producer) assert tracker._last_control_message == ControlType.STREAMING_ENABLED assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP tracker.notify_sse_shutdown_expected() @@ -145,3 +179,217 @@ def test_disconnect_expected(self): assert tracker.handle_occupancy(OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 123, 0)) is None assert tracker.handle_occupancy(OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 124, 1)) is None + + def test_telemetry_non_requested_disconnect(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTracker(telemetry_runtime_producer) + tracker._shutdown_expected = False + tracker.handle_disconnect() + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SSE_CONNECTION_ERROR.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEConnectionError.NON_REQUESTED.value) + + tracker._shutdown_expected = True + tracker.handle_disconnect() + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SSE_CONNECTION_ERROR.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEConnectionError.REQUESTED.value) + + +class StatusTrackerAsyncTests(object): + """Parser tests.""" + + @pytest.mark.asyncio + async def test_initial_status_and_reset(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._occupancy_ok() + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + assert not tracker._shutdown_expected + + tracker._last_control_message = ControlType.STREAMING_PAUSED + tracker._publishers['control_pri'] = 0 + tracker._publishers['control_sec'] = 1 + tracker._last_status_propagated = Status.PUSH_NONRETRYABLE_ERROR + tracker.reset() + assert tracker._occupancy_ok() + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + assert not tracker._shutdown_expected + + @pytest.mark.asyncio + async def test_handling_occupancy(self, mocker): + """Test handling occupancy works properly.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._occupancy_ok() + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 123, 0) + assert await tracker.handle_occupancy(message) is None + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.OCCUPANCY_SEC.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == len(tracker._publishers)) + + # old message + message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 122, 0) + assert await tracker.handle_occupancy(message) is None + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 124, 0) + assert await tracker.handle_occupancy(message) is Status.PUSH_SUBSYSTEM_DOWN + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.STREAMING_STATUS.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEStreamingStatus.PAUSED.value) + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 125, 1) + assert await tracker.handle_occupancy(message) is Status.PUSH_SUBSYSTEM_UP + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.STREAMING_STATUS.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEStreamingStatus.ENABLED.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._type == StreamingEventTypes.OCCUPANCY_PRI.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._data == len(tracker._publishers)) + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 125, 2) + assert await tracker.handle_occupancy(message) is None + + @pytest.mark.asyncio + async def test_handling_control(self, mocker): + """Test handling incoming control messages.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + + message = ControlMessage('control_pri', 123, ControlType.STREAMING_ENABLED) + assert await tracker.handle_control_message(message) is None + + # old message + message = ControlMessage('control_pri', 122, ControlType.STREAMING_PAUSED) + assert await tracker.handle_control_message(message) is None + + message = ControlMessage('control_pri', 124, ControlType.STREAMING_PAUSED) + assert await tracker.handle_control_message(message) is Status.PUSH_SUBSYSTEM_DOWN + + message = ControlMessage('control_pri', 125, ControlType.STREAMING_ENABLED) + assert await tracker.handle_control_message(message) is Status.PUSH_SUBSYSTEM_UP + + message = ControlMessage('control_pri', 126, ControlType.STREAMING_DISABLED) + assert await tracker.handle_control_message(message) is Status.PUSH_NONRETRYABLE_ERROR + + # test that disabling works as well with streaming paused + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + + message = ControlMessage('control_pri', 124, ControlType.STREAMING_PAUSED) + assert await tracker.handle_control_message(message) is Status.PUSH_SUBSYSTEM_DOWN + + message = ControlMessage('control_pri', 126, ControlType.STREAMING_DISABLED) + assert await tracker.handle_control_message(message) is Status.PUSH_NONRETRYABLE_ERROR + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.STREAMING_STATUS.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEStreamingStatus.DISABLED.value) + + + @pytest.mark.asyncio + async def test_control_occupancy_overlap(self, mocker): + """Test control and occupancy messages together.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + + message = ControlMessage('control_pri', 122, ControlType.STREAMING_PAUSED) + assert await tracker.handle_control_message(message) is Status.PUSH_SUBSYSTEM_DOWN + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 123, 0) + assert await tracker.handle_occupancy(message) is None + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 124, 0) + assert await tracker.handle_occupancy(message) is None + + message = ControlMessage('control_pri', 125, ControlType.STREAMING_ENABLED) + assert await tracker.handle_control_message(message) is None + + message = OccupancyMessage('[?occupancy=metrics.publishers]control_pri', 126, 1) + assert await tracker.handle_occupancy(message) is Status.PUSH_SUBSYSTEM_UP + + @pytest.mark.asyncio + async def test_ably_error(self, mocker): + """Test the status tracker reacts appropriately to an ably error.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + + message = AblyError(39999, 100, 'some message', 'http://somewhere') + assert await tracker.handle_ably_error(message) is None + + message = AblyError(50000, 100, 'some message', 'http://somewhere') + assert await tracker.handle_ably_error(message) is None + + tracker.reset() + message = AblyError(40140, 100, 'some message', 'http://somewhere') + assert await tracker.handle_ably_error(message) is Status.PUSH_RETRYABLE_ERROR + + tracker.reset() + message = AblyError(40149, 100, 'some message', 'http://somewhere') + assert await tracker.handle_ably_error(message) is Status.PUSH_RETRYABLE_ERROR + + tracker.reset() + message = AblyError(40150, 100, 'some message', 'http://somewhere') + assert await tracker.handle_ably_error(message) is Status.PUSH_NONRETRYABLE_ERROR + + tracker.reset() + message = AblyError(40139, 100, 'some message', 'http://somewhere') + assert await tracker.handle_ably_error(message) is Status.PUSH_NONRETRYABLE_ERROR + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.ABLY_ERROR.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == 40139) + + + @pytest.mark.asyncio + async def test_disconnect_expected(self, mocker): + """Test that no error is propagated when a disconnect is expected.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + assert tracker._last_control_message == ControlType.STREAMING_ENABLED + assert tracker._last_status_propagated == Status.PUSH_SUBSYSTEM_UP + tracker.notify_sse_shutdown_expected() + + assert await tracker.handle_ably_error(AblyError(40139, 100, 'some message', 'http://somewhere')) is None + assert await tracker.handle_ably_error(AblyError(40149, 100, 'some message', 'http://somewhere')) is None + assert await tracker.handle_ably_error(AblyError(39999, 100, 'some message', 'http://somewhere')) is None + + assert await tracker.handle_control_message(ControlMessage('control_pri', 123, ControlType.STREAMING_ENABLED)) is None + assert await tracker.handle_control_message(ControlMessage('control_pri', 124, ControlType.STREAMING_PAUSED)) is None + assert await tracker.handle_control_message(ControlMessage('control_pri', 125, ControlType.STREAMING_DISABLED)) is None + + assert await tracker.handle_occupancy(OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 123, 0)) is None + assert await tracker.handle_occupancy(OccupancyMessage('[?occupancy=metrics.publishers]control_sec', 124, 1)) is None + + @pytest.mark.asyncio + async def test_telemetry_non_requested_disconnect(self, mocker): + """Test the initial status is ok and reset() works as expected.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + tracker = PushStatusTrackerAsync(telemetry_runtime_producer) + tracker._shutdown_expected = False + await tracker.handle_disconnect() + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SSE_CONNECTION_ERROR.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEConnectionError.NON_REQUESTED.value) + + tracker._shutdown_expected = True + await tracker.handle_disconnect() + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SSE_CONNECTION_ERROR.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSEConnectionError.REQUESTED.value) diff --git a/tests/recorder/test_recorder.py b/tests/recorder/test_recorder.py index 5e559f82..cf226613 100644 --- a/tests/recorder/test_recorder.py +++ b/tests/recorder/test_recorder.py @@ -2,69 +2,274 @@ import pytest -from splitio.recorder.recorder import StandardRecorder, PipelinedRecorder -from splitio.engine.impressions import Manager as ImpressionsManager -from splitio.storage.inmemmory import EventStorage, ImpressionStorage -from splitio.storage.redis import ImpressionPipelinedStorage, EventStorage -from splitio.storage.adapters.redis import RedisAdapter +from splitio.client.listener import ImpressionListenerWrapper, ImpressionListenerWrapperAsync +from splitio.recorder.recorder import StandardRecorder, PipelinedRecorder, StandardRecorderAsync, PipelinedRecorderAsync +from splitio.engine.impressions.impressions import Manager as ImpressionsManager +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.engine.impressions.manager import Counter as ImpressionsCounter +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.storage.inmemmory import EventStorage, ImpressionStorage, InMemoryTelemetryStorage, InMemoryEventStorageAsync, InMemoryImpressionStorageAsync +from splitio.storage.redis import ImpressionPipelinedStorage, EventStorage, RedisEventsStorage, RedisImpressionsStorage, RedisImpressionsStorageAsync, RedisEventsStorageAsync +from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterAsync from splitio.models.impressions import Impression - +from splitio.models.telemetry import MethodExceptionsAndLatencies +from splitio.optional.loaders import asyncio class StandardRecorderTests(object): """StandardRecorderTests test cases.""" def test_standard_recorder(self, mocker): impressions = [ - Impression('k1', 'f1', 'on', 'l1', 123, None, None), - Impression('k1', 'f2', 'on', 'l1', 123, None, None) + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None) ] impmanager = mocker.Mock(spec=ImpressionsManager) - impmanager.process_impressions.return_value = impressions + impmanager.process_impressions.return_value = impressions, 0, [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None)], \ + [{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}], [('k1', 'f1'), ('k1', 'f2')] event = mocker.Mock(spec=EventStorage) impression = mocker.Mock(spec=ImpressionStorage) - recorder = StandardRecorder(impmanager, event, impression) - recorder.record_treatment_stats(impressions, 1, 'some') + telemetry_storage = mocker.Mock(spec=InMemoryTelemetryStorage) + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + listener = mocker.Mock(spec=ImpressionListenerWrapper) + + def record_latency(*args, **kwargs): + self.passed_args = args + + telemetry_storage.record_latency.side_effect = record_latency + + imp_counter = mocker.Mock(spec=ImpressionsCounter()) + unique_keys_tracker = mocker.Mock(spec=UniqueKeysTracker()) + recorder = StandardRecorder(impmanager, event, impression, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer(), + listener=listener, unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + recorder.record_treatment_stats(impressions, 1, MethodExceptionsAndLatencies.TREATMENT, 'get_treatment') assert recorder._impression_storage.put.mock_calls[0][1][0] == impressions + assert(self.passed_args[0] == MethodExceptionsAndLatencies.TREATMENT) + assert(self.passed_args[1] == 1) + assert listener.log_impression.mock_calls == [ + mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), None), + mocker.call(Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None) + ] + assert recorder._imp_counter.track.mock_calls == [mocker.call([{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}])] + assert recorder._unique_keys_tracker.track.mock_calls == [mocker.call('k1', 'f1'), mocker.call('k1', 'f2')] def test_pipelined_recorder(self, mocker): impressions = [ - Impression('k1', 'f1', 'on', 'l1', 123, None, None), - Impression('k1', 'f2', 'on', 'l1', 123, None, None) + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None) ] redis = mocker.Mock(spec=RedisAdapter) - impmanager = mocker.Mock(spec=ImpressionsManager) - impmanager.process_impressions.return_value = impressions - event = mocker.Mock(spec=EventStorage) - impression = mocker.Mock(spec=ImpressionStorage) - recorder = PipelinedRecorder(redis, impmanager, event, impression) - recorder.record_treatment_stats(impressions, 1, 'some') - assert recorder._impression_storage.put.mock_calls[0][1][0] == impressions + def execute(): + return [] + redis().execute = execute - # TODO @matias.melograno Commented until we implement TelemetryV2 - # assert recorder._impression_storage.add_impressions_to_pipe.mock_calls[0][1][0] == impressions - # assert recorder._telemetry_storage.add_latency_to_pipe.mock_calls[0][1][0] == 'some' - # assert recorder._telemetry_storage.add_latency_to_pipe.mock_calls[0][1][1] == 1 + impmanager = mocker.Mock(spec=ImpressionsManager) + impmanager.process_impressions.return_value = impressions, 0, [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None)], \ + [{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}], [('k1', 'f1'), ('k1', 'f2')] + event = mocker.Mock(spec=RedisEventsStorage) + impression = mocker.Mock(spec=RedisImpressionsStorage) + listener = mocker.Mock(spec=ImpressionListenerWrapper) + imp_counter = mocker.Mock(spec=ImpressionsCounter()) + unique_keys_tracker = mocker.Mock(spec=UniqueKeysTracker()) + recorder = PipelinedRecorder(redis, impmanager, event, impression, mocker.Mock(), + listener=listener, unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + recorder.record_treatment_stats(impressions, 1, MethodExceptionsAndLatencies.TREATMENT, 'get_treatment') + assert recorder._impression_storage.add_impressions_to_pipe.mock_calls[0][1][0] == impressions + assert recorder._telemetry_redis_storage.add_latency_to_pipe.mock_calls[0][1][0] == MethodExceptionsAndLatencies.TREATMENT + assert recorder._telemetry_redis_storage.add_latency_to_pipe.mock_calls[0][1][1] == 1 + assert listener.log_impression.mock_calls == [ + mocker.call(Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), None), + mocker.call(Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None) + ] + assert recorder._imp_counter.track.mock_calls == [mocker.call([{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}])] + assert recorder._unique_keys_tracker.track.mock_calls == [mocker.call('k1', 'f1'), mocker.call('k1', 'f2')] def test_sampled_recorder(self, mocker): impressions = [ - Impression('k1', 'f1', 'on', 'l1', 123, None, None), - Impression('k1', 'f2', 'on', 'l1', 123, None, None) + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None) ] redis = mocker.Mock(spec=RedisAdapter) impmanager = mocker.Mock(spec=ImpressionsManager) - impmanager.process_impressions.return_value = impressions + impmanager.process_impressions.return_value = impressions, 0, [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None) + ], [], [] + event = mocker.Mock(spec=EventStorage) impression = mocker.Mock(spec=ImpressionStorage) - recorder = PipelinedRecorder(redis, impmanager, event, impression, 0.5) + imp_counter = mocker.Mock(spec=ImpressionsCounter()) + unique_keys_tracker = mocker.Mock(spec=UniqueKeysTracker()) + recorder = PipelinedRecorder(redis, impmanager, event, impression, 0.5, mocker.Mock(), imp_counter=imp_counter, unique_keys_tracker=unique_keys_tracker) def put(x): return + recorder._impression_storage.put.side_effect = put + + for _ in range(100): + recorder.record_treatment_stats(impressions, 1, 'some', 'get_treatment') + print(recorder._impression_storage.put.call_count) + assert recorder._impression_storage.put.call_count < 80 + assert recorder._imp_counter.track.mock_calls == [] + assert recorder._unique_keys_tracker.track.mock_calls == [] + +class StandardRecorderAsyncTests(object): + """StandardRecorder async test cases.""" + + @pytest.mark.asyncio + async def test_standard_recorder(self, mocker): + impressions = [ + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None) + ] + impmanager = mocker.Mock(spec=ImpressionsManager) + impmanager.process_impressions.return_value = impressions, 0, [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), {'att1': 'val'}), + (Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None)], \ + [{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}], [('k1', 'f1'), ('k1', 'f2')] + event = mocker.Mock(spec=InMemoryEventStorageAsync) + impression = mocker.Mock(spec=InMemoryImpressionStorageAsync) + telemetry_storage = mocker.Mock(spec=InMemoryTelemetryStorage) + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + listener = mocker.Mock(spec=ImpressionListenerWrapperAsync) + self.listener_impressions = [] + self.listener_attributes = [] + async def log_impression(impressions, attributes): + self.listener_impressions.append(impressions) + self.listener_attributes.append(attributes) + listener.log_impression = log_impression + + async def record_latency(*args, **kwargs): + self.passed_args = args + telemetry_storage.record_latency.side_effect = record_latency + + imp_counter = mocker.Mock(spec=ImpressionsCounter()) + unique_keys_tracker = mocker.Mock(spec=UniqueKeysTrackerAsync()) + recorder = StandardRecorderAsync(impmanager, event, impression, telemetry_producer.get_telemetry_evaluation_producer(), telemetry_producer.get_telemetry_runtime_producer(), + listener=listener, unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + self.impressions = [] + async def put(x): + self.impressions = x + return + recorder._impression_storage.put = put + + self.count = [] + def track(x): + self.count = x + recorder._imp_counter.track = track + + self.unique_keys = [] + async def track2(x, y): + self.unique_keys.append((x, y)) + recorder._unique_keys_tracker.track = track2 + + await recorder.record_treatment_stats(impressions, 1, MethodExceptionsAndLatencies.TREATMENT, 'get_treatment') + await asyncio.sleep(1) + + assert self.impressions == impressions + assert(self.passed_args[0] == MethodExceptionsAndLatencies.TREATMENT) + assert(self.passed_args[1] == 1) + assert self.listener_impressions == [ + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), + ] + assert self.listener_attributes == [{'att1': 'val'}, None] + assert self.count == [{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}] + assert self.unique_keys == [('k1', 'f1'), ('k1', 'f2')] + + @pytest.mark.asyncio + async def test_pipelined_recorder(self, mocker): + impressions = [ + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None) + ] + redis = mocker.Mock(spec=RedisAdapterAsync) + async def execute(): + return [] + redis().execute = execute + impmanager = mocker.Mock(spec=ImpressionsManager) + impmanager.process_impressions.return_value = impressions, 0, [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), {'att1': 'val'}), + (Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None)], \ + [{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}], [('k1', 'f1'), ('k1', 'f2')] + event = mocker.Mock(spec=RedisEventsStorageAsync) + impression = mocker.Mock(spec=RedisImpressionsStorageAsync) + listener = mocker.Mock(spec=ImpressionListenerWrapperAsync) + self.listener_impressions = [] + self.listener_attributes = [] + async def log_impression(impressions, attributes): + self.listener_impressions.append(impressions) + self.listener_attributes.append(attributes) + listener.log_impression = log_impression + + imp_counter = mocker.Mock(spec=ImpressionsCounter()) + unique_keys_tracker = mocker.Mock(spec=UniqueKeysTrackerAsync()) + recorder = PipelinedRecorderAsync(redis, impmanager, event, impression, mocker.Mock(), + listener=listener, unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + self.count = [] + def track(x): + self.count = x + recorder._imp_counter.track = track + + self.unique_keys = [] + async def track2(x, y): + self.unique_keys.append((x, y)) + recorder._unique_keys_tracker.track = track2 + + await recorder.record_treatment_stats(impressions, 1, MethodExceptionsAndLatencies.TREATMENT, 'get_treatment') + await asyncio.sleep(.2) + assert recorder._impression_storage.add_impressions_to_pipe.mock_calls[0][1][0] == impressions + assert recorder._telemetry_redis_storage.add_latency_to_pipe.mock_calls[0][1][0] == MethodExceptionsAndLatencies.TREATMENT + assert recorder._telemetry_redis_storage.add_latency_to_pipe.mock_calls[0][1][1] == 1 + assert self.listener_impressions == [ + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), + ] + assert self.listener_attributes == [{'att1': 'val'}, None] + assert self.count == [{"f": "f1", "ks": ["l1"]}, {"f": "f2", "ks": ["l1"]}] + assert self.unique_keys == [('k1', 'f1'), ('k1', 'f2')] + + @pytest.mark.asyncio + async def test_sampled_recorder(self, mocker): + impressions = [ + Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), + Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None) + ] + redis = mocker.Mock(spec=RedisAdapterAsync) + impmanager = mocker.Mock(spec=ImpressionsManager) + impmanager.process_impressions.return_value = impressions, 0, [ + (Impression('k1', 'f1', 'on', 'l1', 123, None, None, None, None), None), + (Impression('k1', 'f2', 'on', 'l1', 123, None, None, None, None), None) + ], [], [] + event = mocker.Mock(spec=RedisEventsStorageAsync) + impression = mocker.Mock(spec=RedisImpressionsStorageAsync) + imp_counter = mocker.Mock(spec=ImpressionsCounter()) + unique_keys_tracker = mocker.Mock(spec=UniqueKeysTrackerAsync()) + recorder = PipelinedRecorderAsync(redis, impmanager, event, impression, 0.5, mocker.Mock(), + unique_keys_tracker=unique_keys_tracker, imp_counter=imp_counter) + self.count = [] + async def track(x): + self.count = x + recorder._imp_counter.track = track + + self.unique_keys = [] + async def track2(x, y): + self.unique_keys.append((x, y)) + recorder._unique_keys_tracker.track = track2 + + async def put(x): + return recorder._impression_storage.put.side_effect = put for _ in range(100): - recorder.record_treatment_stats(impressions, 1, 'some') + await recorder.record_treatment_stats(impressions, 1, 'some', 'get_treatment') print(recorder._impression_storage.put.call_count) assert recorder._impression_storage.put.call_count < 80 + assert self.count == [] + assert self.unique_keys == [] diff --git a/tests/storage/adapters/test_cache_trait.py b/tests/storage/adapters/test_cache_trait.py index 15f3b13a..5643cb32 100644 --- a/tests/storage/adapters/test_cache_trait.py +++ b/tests/storage/adapters/test_cache_trait.py @@ -6,6 +6,7 @@ import pytest from splitio.storage.adapters import cache_trait +from splitio.optional.loaders import asyncio class CacheTraitTests(object): """Cache trait test cases.""" @@ -130,3 +131,11 @@ def test_decorate(self, mocker): assert cache_trait.decorate(key_func, 0, 10)(user_func) is user_func assert cache_trait.decorate(key_func, 10, 0)(user_func) is user_func assert cache_trait.decorate(key_func, 0, 0)(user_func) is user_func + + @pytest.mark.asyncio + async def test_async_add_and_get_key(self, mocker): + cache = cache_trait.LocalMemoryCacheAsync(None, None, 1, 1) + await cache.add_key('split', {'split_name': 'split'}) + assert await cache.get_key('split') == {'split_name': 'split'} + await asyncio.sleep(1) + assert await cache.get_key('split') == None diff --git a/tests/storage/adapters/test_redis_adapter.py b/tests/storage/adapters/test_redis_adapter.py index d2bf686f..9888c853 100644 --- a/tests/storage/adapters/test_redis_adapter.py +++ b/tests/storage/adapters/test_redis_adapter.py @@ -1,7 +1,9 @@ """Redis storage adapter test module.""" import pytest +from redis.asyncio.client import Redis as aioredis from splitio.storage.adapters import redis +from splitio.storage.adapters.redis import _build_default_client_async, _build_sentinel_client_async from redis import StrictRedis, Redis from redis.sentinel import Sentinel @@ -55,6 +57,12 @@ def test_forwarding(self, mocker): adapter.incr('key1') assert redis_mock.incr.mock_calls[0] == mocker.call('some_prefix.key1', 1) + adapter.hincrby('key1', 'name1') + assert redis_mock.hincrby.mock_calls[0] == mocker.call('some_prefix.key1', 'name1', 1) + + adapter.hincrby('key1', 'name1', 5) + assert redis_mock.hincrby.mock_calls[1] == mocker.call('some_prefix.key1', 'name1', 5) + adapter.getset('key1', 'new_value') assert redis_mock.getset.mock_calls[0] == mocker.call('some_prefix.key1', 'new_value') @@ -81,6 +89,7 @@ def test_adapter_building(self, mocker): 'redisHost': 'some_host', 'redisPort': 1234, 'redisDb': 0, + 'redisUsername': 'redis_user', 'redisPassword': 'some_password', 'redisSocketTimeout': 123, 'redisSocketConnectTimeout': 456, @@ -90,7 +99,6 @@ def test_adapter_building(self, mocker): 'redisUnixSocketPath': '/tmp/socket', 'redisEncoding': 'utf-8', 'redisEncodingErrors': 'strict', - 'redisErrors': 'abc', 'redisDecodeResponses': True, 'redisRetryOnTimeout': True, 'redisSsl': True, @@ -107,6 +115,7 @@ def test_adapter_building(self, mocker): host='some_host', port=1234, db=0, + username='redis_user', password='some_password', socket_timeout=123, socket_connect_timeout=456, @@ -116,7 +125,6 @@ def test_adapter_building(self, mocker): unix_socket_path='/tmp/socket', encoding='utf-8', encoding_errors='strict', - errors='abc', decode_responses=True, retry_on_timeout=True, ssl=True, @@ -131,6 +139,7 @@ def test_adapter_building(self, mocker): 'redisSentinels': [('123.123.123.123', 1), ('456.456.456.456', 2), ('789.789.789.789', 3)], 'redisMasterService': 'some_master', 'redisDb': 0, + 'redisUsername': 'redis_user', 'redisPassword': 'some_password', 'redisSocketTimeout': 123, 'redisSocketConnectTimeout': 456, @@ -140,7 +149,6 @@ def test_adapter_building(self, mocker): 'redisUnixSocketPath': '/tmp/socket', 'redisEncoding': 'utf-8', 'redisEncodingErrors': 'strict', - 'redisErrors': 'abc', 'redisDecodeResponses': True, 'redisRetryOnTimeout': True, 'redisSsl': False, @@ -156,6 +164,7 @@ def test_adapter_building(self, mocker): assert sentinel_mock.mock_calls[0] == mocker.call( [('123.123.123.123', 1), ('456.456.456.456', 2), ('789.789.789.789', 3)], db=0, + username='redis_user', password='some_password', socket_timeout=123, socket_connect_timeout=456, @@ -178,6 +187,366 @@ def test_sentinel_ssl_fails(self): }) +class RedisStorageAdapterAsyncTests(object): + """Redis storage adapter test cases.""" + + @pytest.mark.asyncio + async def test_forwarding(self, mocker): + """Test that all redis functions forward prefix appropriately.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.arg = None + async def keys(sel, args): + self.arg = args + return ['some_prefix.key1', 'some_prefix.key2'] + mocker.patch('redis.asyncio.client.Redis.keys', new=keys) + await adapter.keys('*') + assert self.arg == 'some_prefix.*' + + self.key = None + self.value = None + async def set(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.set', new=set) + await adapter.set('key1', 'value1') + assert self.key == 'some_prefix.key1' + assert self.value == 'value1' + + self.key = None + async def get(sel, key): + self.key = key + return 'value1' + mocker.patch('redis.asyncio.client.Redis.get', new=get) + await adapter.get('some_key') + assert self.key == 'some_prefix.some_key' + + self.key = None + self.value = None + self.exp = None + async def setex(sel, key, exp, value): + self.key = key + self.value = value + self.exp = exp + mocker.patch('redis.asyncio.client.Redis.setex', new=setex) + await adapter.setex('some_key', 123, 'some_value') + assert self.key == 'some_prefix.some_key' + assert self.exp == 123 + assert self.value == 'some_value' + + self.key = None + async def delete(sel, key): + self.key = key + mocker.patch('redis.asyncio.client.Redis.delete', new=delete) + await adapter.delete('some_key') + assert self.key == 'some_prefix.some_key' + + self.keys = None + async def mget(sel, keys): + self.keys = keys + return ['value1', 'value2', 'value3'] + mocker.patch('redis.asyncio.client.Redis.mget', new=mget) + await adapter.mget(['key1', 'key2', 'key3']) + assert self.keys == ['some_prefix.key1', 'some_prefix.key2', 'some_prefix.key3'] + + self.key = None + self.value = None + self.value2 = None + async def sadd(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Redis.sadd', new=sadd) + await adapter.sadd('s1', 'value1', 'value2') + assert self.key == 'some_prefix.s1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.value = None + self.value2 = None + async def srem(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Redis.srem', new=srem) + await adapter.srem('s1', 'value1', 'value2') + assert self.key == 'some_prefix.s1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.value = None + async def sismember(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.sismember', new=sismember) + await adapter.sismember('s1', 'value1') + assert self.key == 'some_prefix.s1' + assert self.value == 'value1' + + self.key = None + self.key2 = None + self.key3 = None + self.script = None + self.value = None + async def eval(sel, script, value, key, key2, key3): + self.key = key + self.key2 = key2 + self.key3 = key3 + self.script = script + self.value = value + mocker.patch('redis.asyncio.client.Redis.eval', new=eval) + await adapter.eval('script', 3, 'key1', 'key2', 'key3') + assert self.script == 'script' + assert self.value == 3 + assert self.key == 'some_prefix.key1' + assert self.key2 == 'some_prefix.key2' + assert self.key3 == 'some_prefix.key3' + + self.key = None + self.value = None + self.name = None + async def hset(sel, key, name, value): + self.key = key + self.value = value + self.name = name + mocker.patch('redis.asyncio.client.Redis.hset', new=hset) + await adapter.hset('key1', 'name', 'value') + assert self.key == 'some_prefix.key1' + assert self.name == 'name' + assert self.value == 'value' + + self.key = None + self.name = None + async def hget(sel, key, name): + self.key = key + self.name = name + mocker.patch('redis.asyncio.client.Redis.hget', new=hget) + await adapter.hget('key1', 'name') + assert self.key == 'some_prefix.key1' + assert self.name == 'name' + + self.key = None + self.value = None + async def incr(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.incr', new=incr) + await adapter.incr('key1') + assert self.key == 'some_prefix.key1' + assert self.value == 1 + + self.key = None + self.value = None + self.name = None + async def hincrby(sel, key, name, value): + self.key = key + self.value = value + self.name = name + mocker.patch('redis.asyncio.client.Redis.hincrby', new=hincrby) + await adapter.hincrby('key1', 'name1') + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 1 + + await adapter.hincrby('key1', 'name1', 5) + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 5 + + self.key = None + self.value = None + async def getset(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.getset', new=getset) + await adapter.getset('key1', 'new_value') + assert self.key == 'some_prefix.key1' + assert self.value == 'new_value' + + self.key = None + self.value = None + self.value2 = None + async def rpush(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Redis.rpush', new=rpush) + await adapter.rpush('key1', 'value1', 'value2') + assert self.key == 'some_prefix.key1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.exp = None + async def expire(sel, key, exp): + self.key = key + self.exp = exp + mocker.patch('redis.asyncio.client.Redis.expire', new=expire) + await adapter.expire('key1', 10) + assert self.key == 'some_prefix.key1' + assert self.exp == 10 + + self.key = None + async def rpop(sel, key): + self.key = key + mocker.patch('redis.asyncio.client.Redis.rpop', new=rpop) + await adapter.rpop('key1') + assert self.key == 'some_prefix.key1' + + self.key = None + async def ttl(sel, key): + self.key = key + mocker.patch('redis.asyncio.client.Redis.ttl', new=ttl) + await adapter.ttl('key1') + assert self.key == 'some_prefix.key1' + + @pytest.mark.asyncio + async def test_adapter_building(self, mocker): + """Test buildin different types of client according to parameters received.""" + + config = { + 'redisHost': 'some_host', + 'redisPort': 1234, + 'redisDb': 0, + 'redisPassword': 'some_password', + 'redisSocketTimeout': 123, + 'redisSocketKeepalive': 789, + 'redisSocketKeepaliveOptions': 10, + 'redisUnixSocketPath': '/tmp/socket', + 'redisEncoding': 'utf-8', + 'redisEncodingErrors': 'strict', + 'redisDecodeResponses': True, + 'redisRetryOnTimeout': True, + 'redisSsl': True, + 'redisSslKeyfile': '/ssl.cert', + 'redisSslCertfile': '/ssl2.cert', + 'redisSslCertReqs': 'abc', + 'redisSslCaCerts': 'def', + 'redisMaxConnections': 5, + 'redisPrefix': 'some_prefix' + } + + def redis_init(se, connection_pool, + socket_connect_timeout, + socket_keepalive, + socket_keepalive_options, + unix_socket_path, + encoding_errors, + retry_on_timeout, + ssl, + ssl_keyfile, + ssl_certfile, + ssl_cert_reqs, + ssl_ca_certs): + self.connection_pool=connection_pool + self.socket_connect_timeout=socket_connect_timeout + self.socket_keepalive=socket_keepalive + self.socket_keepalive_options=socket_keepalive_options + self.unix_socket_path=unix_socket_path + self.encoding_errors=encoding_errors + self.retry_on_timeout=retry_on_timeout + self.ssl=ssl + self.ssl_keyfile=ssl_keyfile + self.ssl_certfile=ssl_certfile + self.ssl_cert_reqs=ssl_cert_reqs + self.ssl_ca_certs=ssl_ca_certs + mocker.patch('redis.asyncio.client.Redis.__init__', new=redis_init) + + redis_mock = await _build_default_client_async(config) + + assert self.connection_pool.connection_kwargs['host'] == 'some_host' + assert self.connection_pool.connection_kwargs['port'] == 1234 + assert self.connection_pool.connection_kwargs['db'] == 0 + assert self.connection_pool.connection_kwargs['password'] == 'some_password' + assert self.connection_pool.connection_kwargs['encoding'] == 'utf-8' + assert self.connection_pool.connection_kwargs['decode_responses'] == True + + assert self.socket_keepalive == 789 + assert self.socket_keepalive_options == 10 + assert self.unix_socket_path == '/tmp/socket' + assert self.encoding_errors == 'strict' + assert self.retry_on_timeout == True + assert self.ssl == True + assert self.ssl_keyfile == '/ssl.cert' + assert self.ssl_certfile == '/ssl2.cert' + assert self.ssl_cert_reqs == 'abc' + assert self.ssl_ca_certs == 'def' + + def create_sentinel(se, + sentinels, + db, + password, + encoding, + max_connections, + encoding_errors, + decode_responses, + connection_pool, + socket_connect_timeout): + self.sentinels=sentinels + self.db=db + self.password=password + self.encoding=encoding + self.max_connections=max_connections + self.encoding_errors=encoding_errors, + self.decode_responses=decode_responses, + self.connection_pool=connection_pool, + self.socket_connect_timeout=socket_connect_timeout + mocker.patch('redis.asyncio.sentinel.Sentinel.__init__', new=create_sentinel) + + def master_for(se, + master_service, + socket_timeout, + socket_keepalive, + socket_keepalive_options, + encoding_errors, + retry_on_timeout, + ssl): + self.master_service = master_service, + self.socket_timeout = socket_timeout, + self.socket_keepalive = socket_keepalive, + self.socket_keepalive_options = socket_keepalive_options, + self.encoding_errors = encoding_errors, + self.retry_on_timeout = retry_on_timeout, + self.ssl = ssl + mocker.patch('redis.asyncio.sentinel.Sentinel.master_for', new=master_for) + + config = { + 'redisSentinels': [('123.123.123.123', 1), ('456.456.456.456', 2), ('789.789.789.789', 3)], + 'redisMasterService': 'some_master', + 'redisDb': 0, + 'redisPassword': 'some_password', + 'redisSocketTimeout': 123, + 'redisSocketConnectTimeout': 456, + 'redisSocketKeepalive': 789, + 'redisSocketKeepaliveOptions': 10, + 'redisConnectionPool': 20, + 'redisUnixSocketPath': '/tmp/socket', + 'redisEncoding': 'utf-8', + 'redisEncodingErrors': 'strict', + 'redisDecodeResponses': True, + 'redisRetryOnTimeout': True, + 'redisSsl': False, + 'redisMaxConnections': 5, + 'redisPrefix': 'some_prefix' + } + await _build_sentinel_client_async(config) + assert self.sentinels == [('123.123.123.123', 1), ('456.456.456.456', 2), ('789.789.789.789', 3)] + assert self.db == 0 + assert self.password == 'some_password' + assert self.encoding == 'utf-8' + assert self.max_connections == 5 + assert self.ssl == False + assert self.master_service == ('some_master',) + assert self.socket_timeout == (123,) + assert self.socket_keepalive == (789,) + assert self.socket_keepalive_options == (10,) + assert self.encoding_errors == ('strict',) + assert self.retry_on_timeout == (True,) + + class RedisPipelineAdapterTests(object): """Redis pipelined adapter test cases.""" @@ -194,3 +563,68 @@ def test_forwarding(self, mocker): adapter.incr('key1') assert redis_mock_2.incr.mock_calls[0] == mocker.call('some_prefix.key1', 1) + + adapter.hincrby('key1', 'name1') + assert redis_mock_2.hincrby.mock_calls[0] == mocker.call('some_prefix.key1', 'name1', 1) + + adapter.hincrby('key1', 'name1', 5) + assert redis_mock_2.hincrby.mock_calls[1] == mocker.call('some_prefix.key1', 'name1', 5) + + +class RedisPipelineAdapterAsyncTests(object): + """Redis pipelined adapter test cases.""" + + @pytest.mark.asyncio + async def test_forwarding(self, mocker): + """Test that all redis functions forward prefix appropriately.""" + redis_mock = await aioredis.from_url("redis://localhost") + prefix_helper = redis.PrefixHelper('some_prefix') + adapter = redis.RedisPipelineAdapterAsync(redis_mock, prefix_helper) + + self.key = None + self.value = None + self.value2 = None + def rpush(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Pipeline.rpush', new=rpush) + adapter.rpush('key1', 'value1', 'value2') + assert self.key == 'some_prefix.key1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.value = None + def incr(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Pipeline.incr', new=incr) + adapter.incr('key1') + assert self.key == 'some_prefix.key1' + assert self.value == 1 + + self.key = None + self.value = None + self.name = None + def hincrby(sel, key, name, value): + self.key = key + self.value = value + self.name = name + mocker.patch('redis.asyncio.client.Pipeline.hincrby', new=hincrby) + adapter.hincrby('key1', 'name1') + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 1 + + adapter.hincrby('key1', 'name1', 5) + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 5 + + self.called = False + async def execute(*_): + self.called = True + mocker.patch('redis.asyncio.client.Pipeline.execute', new=execute) + await adapter.execute() + assert self.called diff --git a/tests/storage/test_flag_sets.py b/tests/storage/test_flag_sets.py new file mode 100644 index 00000000..995117cb --- /dev/null +++ b/tests/storage/test_flag_sets.py @@ -0,0 +1,65 @@ +import pytest + +from splitio.storage import FlagSetsFilter +from splitio.storage.inmemmory import FlagSets + +class FlagSetsFilterTests(object): + """Flag sets filter storage tests.""" + def test_without_initial_set(self): + flag_set = FlagSets() + assert flag_set.sets_feature_flag_map == {} + + flag_set._add_flag_set('set1') + assert flag_set.get_flag_set('set1') == set({}) + assert flag_set.flag_set_exist('set1') == True + assert flag_set.flag_set_exist('set2') == False + + flag_set.add_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split1'} + flag_set.add_feature_flag_to_flag_set('set1', 'split2') + assert flag_set.get_flag_set('set1') == {'split1', 'split2'} + flag_set.remove_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split2'} + flag_set._remove_flag_set('set2') + assert flag_set.sets_feature_flag_map == {'set1': set({'split2'})} + flag_set._remove_flag_set('set1') + assert flag_set.sets_feature_flag_map == {} + assert flag_set.flag_set_exist('set1') == False + + def test_with_initial_set(self): + flag_set = FlagSets(['set1', 'set2']) + assert flag_set.sets_feature_flag_map == {'set1': set(), 'set2': set()} + + flag_set._add_flag_set('set1') + assert flag_set.get_flag_set('set1') == set({}) + assert flag_set.flag_set_exist('set1') == True + assert flag_set.flag_set_exist('set2') == True + + flag_set.add_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split1'} + flag_set.add_feature_flag_to_flag_set('set1', 'split2') + assert flag_set.get_flag_set('set1') == {'split1', 'split2'} + flag_set.remove_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split2'} + flag_set._remove_flag_set('set2') + assert flag_set.sets_feature_flag_map == {'set1': set({'split2'})} + flag_set._remove_flag_set('set1') + assert flag_set.sets_feature_flag_map == {} + assert flag_set.flag_set_exist('set1') == False + + def test_flag_set_filter(self): + flag_set_filter = FlagSetsFilter() + assert flag_set_filter.flag_sets == set() + assert not flag_set_filter.should_filter + + flag_set_filter = FlagSetsFilter(['set1', 'set2']) + assert flag_set_filter.flag_sets == set({'set1', 'set2'}) + assert flag_set_filter.should_filter + assert flag_set_filter.intersect(set({'set1', 'set2'})) + assert flag_set_filter.intersect(set({'set1', 'set2', 'set5'})) + assert not flag_set_filter.intersect(set({'set4'})) + assert not flag_set_filter.set_exist('set4') + assert flag_set_filter.set_exist('set1') + + flag_set_filter = FlagSetsFilter(['set5', 'set2', 'set6', 'set1']) + assert flag_set_filter.sorted_flag_sets == ['set1', 'set2', 'set5', 'set6'] \ No newline at end of file diff --git a/tests/storage/test_inmemory_storage.py b/tests/storage/test_inmemory_storage.py index 8594a443..d46980aa 100644 --- a/tests/storage/test_inmemory_storage.py +++ b/tests/storage/test_inmemory_storage.py @@ -1,33 +1,92 @@ """In-Memory storage test module.""" # pylint: disable=no-self-use +import random +import pytest +import copy +import queue +import asyncio + from splitio.models.splits import Split from splitio.models.segments import Segment from splitio.models.impressions import Impression from splitio.models.events import Event, EventWrapper - -from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \ - InMemoryImpressionStorage, InMemoryEventStorage - +from splitio.models.events import SdkInternalEvent +import splitio.models.telemetry as ModelTelemetry +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.events.events_metadata import SdkEventType +from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, InMemorySegmentStorageAsync, InMemorySplitStorageAsync, \ + InMemoryImpressionStorage, InMemoryEventStorage, InMemoryTelemetryStorage, InMemoryImpressionStorageAsync, InMemoryEventStorageAsync, \ + InMemoryTelemetryStorageAsync, FlagSets, InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync +from splitio.models.rule_based_segments import RuleBasedSegment +from splitio.models import rule_based_segments + +class FlagSetsFilterTests(object): + """Flag sets filter storage tests.""" + def test_without_initial_set(self): + flag_set = FlagSets() + assert flag_set.sets_feature_flag_map == {} + + flag_set._add_flag_set('set1') + assert flag_set.get_flag_set('set1') == set({}) + assert flag_set.flag_set_exist('set1') == True + assert flag_set.flag_set_exist('set2') == False + + flag_set.add_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split1'} + flag_set.add_feature_flag_to_flag_set('set1', 'split2') + assert flag_set.get_flag_set('set1') == {'split1', 'split2'} + flag_set.remove_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split2'} + flag_set._remove_flag_set('set2') + assert flag_set.sets_feature_flag_map == {'set1': set({'split2'})} + flag_set._remove_flag_set('set1') + assert flag_set.sets_feature_flag_map == {} + assert flag_set.flag_set_exist('set1') == False + + def test_with_initial_set(self): + flag_set = FlagSets(['set1', 'set2']) + assert flag_set.sets_feature_flag_map == {'set1': set(), 'set2': set()} + + flag_set._add_flag_set('set1') + assert flag_set.get_flag_set('set1') == set({}) + assert flag_set.flag_set_exist('set1') == True + assert flag_set.flag_set_exist('set2') == True + + flag_set.add_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split1'} + flag_set.add_feature_flag_to_flag_set('set1', 'split2') + assert flag_set.get_flag_set('set1') == {'split1', 'split2'} + flag_set.remove_feature_flag_to_flag_set('set1', 'split1') + assert flag_set.get_flag_set('set1') == {'split2'} + flag_set._remove_flag_set('set2') + assert flag_set.sets_feature_flag_map == {'set1': set({'split2'})} + flag_set._remove_flag_set('set1') + assert flag_set.sets_feature_flag_map == {} + assert flag_set.flag_set_exist('set1') == False class InMemorySplitStorageTests(object): """In memory split storage test cases.""" def test_storing_retrieving_splits(self, mocker): """Test storing and retrieving splits works.""" - storage = InMemorySplitStorage() + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) split = mocker.Mock(spec=Split) name_property = mocker.PropertyMock() name_property.return_value = 'some_split' type(split).name = name_property + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split).sets = sets_property - storage.put(split) + storage.update([split], [], -1) assert storage.get('some_split') == split assert storage.get_split_names() == ['some_split'] assert storage.get_all_splits() == [split] assert storage.get('nonexistant_split') is None - storage.remove('some_split') + storage.update([], ['some_split'], -1) assert storage.get('some_split') is None def test_get_splits(self, mocker): @@ -36,26 +95,34 @@ def test_get_splits(self, mocker): name1_prop = mocker.PropertyMock() name1_prop.return_value = 'split1' type(split1).name = name1_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + split2 = mocker.Mock() name2_prop = mocker.PropertyMock() name2_prop.return_value = 'split2' type(split2).name = name2_prop + type(split2).sets = sets_property - storage = InMemorySplitStorage() - storage.put(split1) - storage.put(split2) + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + storage.update([split1, split2], [], -1) splits = storage.fetch_many(['split1', 'split2', 'split3']) assert len(splits) == 3 assert splits['split1'].name == 'split1' + assert splits['split1'].sets == ['set_1'] assert splits['split2'].name == 'split2' + assert splits['split2'].sets == ['set_1'] assert 'split3' in splits def test_store_get_changenumber(self): """Test that storing and retrieving change numbers works.""" - storage = InMemorySplitStorage() + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) assert storage.get_change_number() == -1 - storage.set_change_number(5) + storage.update([], [], 5) assert storage.get_change_number() == 5 def test_get_split_names(self, mocker): @@ -64,14 +131,19 @@ def test_get_split_names(self, mocker): name1_prop = mocker.PropertyMock() name1_prop.return_value = 'split1' type(split1).name = name1_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + split2 = mocker.Mock() name2_prop = mocker.PropertyMock() name2_prop.return_value = 'split2' type(split2).name = name2_prop + type(split2).sets = sets_property - storage = InMemorySplitStorage() - storage.put(split1) - storage.put(split2) + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + storage.update([split1, split2], [], -1) assert set(storage.get_split_names()) == set(['split1', 'split2']) @@ -81,14 +153,19 @@ def test_get_all_splits(self, mocker): name1_prop = mocker.PropertyMock() name1_prop.return_value = 'split1' type(split1).name = name1_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + split2 = mocker.Mock() name2_prop = mocker.PropertyMock() name2_prop.return_value = 'split2' type(split2).name = name2_prop + type(split2).sets = sets_property - storage = InMemorySplitStorage() - storage.put(split1) - storage.put(split2) + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + storage.update([split1, split2], [], -1) all_splits = storage.get_all_splits() assert next(s for s in all_splits if s.name == 'split1') @@ -115,72 +192,84 @@ def test_is_valid_traffic_type(self, mocker): type(split1).traffic_type_name = tt_user type(split2).traffic_type_name = tt_account type(split3).traffic_type_name = tt_user + sets_property = mocker.PropertyMock() + sets_property.return_value = [] + type(split1).sets = sets_property + type(split2).sets = sets_property + type(split3).sets = sets_property - storage = InMemorySplitStorage() + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) - storage.put(split1) + storage.update([split1], [], -1) assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is False - storage.put(split2) + storage.update([split2], [], -1) assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is True - storage.put(split3) + storage.update([split3], [], -1) assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is True - storage.remove('split1') + storage.update([], ['split1'], -1) assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is True - storage.remove('split2') + storage.update([], ['split2'], -1) assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is False - storage.remove('split3') + storage.update([], ['split3'], -1) assert storage.is_valid_traffic_type('user') is False assert storage.is_valid_traffic_type('account') is False def test_traffic_type_inc_dec_logic(self, mocker): """Test that adding/removing split, handles traffic types correctly.""" - storage = InMemorySplitStorage() + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) split1 = mocker.Mock() name1_prop = mocker.PropertyMock() name1_prop.return_value = 'split1' type(split1).name = name1_prop - split2 = mocker.Mock() name2_prop = mocker.PropertyMock() name2_prop.return_value = 'split1' type(split2).name = name2_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = None + type(split1).sets = sets_property + type(split2).sets = sets_property tt_user = mocker.PropertyMock() tt_user.return_value = 'user' - tt_account = mocker.PropertyMock() tt_account.return_value = 'account' - + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property type(split1).traffic_type_name = tt_user type(split2).traffic_type_name = tt_account - storage.put(split1) + storage.update([split1], [], -1) assert storage.is_valid_traffic_type('user') is True assert storage.is_valid_traffic_type('account') is False - storage.put(split2) + storage.update([split2], [], -1) assert storage.is_valid_traffic_type('user') is False assert storage.is_valid_traffic_type('account') is True def test_kill_locally(self): """Test kill local.""" - storage = InMemorySplitStorage() + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) split = Split('some_split', 123456789, False, 'some', 'traffic_type', 'ACTIVE', 1) - storage.put(split) - storage.set_change_number(1) + storage.update([split], [], 1) storage.kill_locally('test', 'default_treatment', 2) assert storage.get('test') is None @@ -193,13 +282,469 @@ def test_kill_locally(self): storage.kill_locally('some_split', 'default_treatment', 3) assert storage.get('some_split').change_number == 3 + def test_flag_sets_with_config_sets(self): + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue, ['set10', 'set02', 'set05']) + assert storage.flag_set_filter.flag_sets == {'set10', 'set02', 'set05'} + assert storage.flag_set_filter.should_filter + + assert storage.flag_set.sets_feature_flag_map == {'set10': set(), 'set02': set(), 'set05': set()} + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set10', 'set02']) + split2 = Split('split2', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set05', 'set02']) + split3 = Split('split3', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set04', 'set05']) + storage.update([split1], [], 1) + assert storage.get_feature_flags_by_sets(['set10']) == ['split1'] + assert storage.get_feature_flags_by_sets(['set02']) == ['split1'] + assert storage.get_feature_flags_by_sets(['set02', 'set10']) == ['split1'] + assert storage.is_flag_set_exist('set10') + assert storage.is_flag_set_exist('set02') + assert not storage.is_flag_set_exist('set03') + + storage.update([split2], [], 1) + assert storage.get_feature_flags_by_sets(['set05']) == ['split2'] + assert sorted(storage.get_feature_flags_by_sets(['set02', 'set05'])) == ['split1', 'split2'] + assert storage.is_flag_set_exist('set05') + + storage.update([], [split2.name], 1) + assert storage.is_flag_set_exist('set05') + assert storage.get_feature_flags_by_sets(['set02']) == ['split1'] + assert storage.get_feature_flags_by_sets(['set05']) == [] + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set02']) + storage.update([split1], [], 1) + assert storage.is_flag_set_exist('set10') + assert storage.get_feature_flags_by_sets(['set02']) == ['split1'] + + storage.update([], [split1.name], 1) + assert storage.get_feature_flags_by_sets(['set02']) == [] + assert storage.flag_set.sets_feature_flag_map == {'set10': set(), 'set02': set(), 'set05': set()} + + storage.update([split3], [], 1) + assert storage.get_feature_flags_by_sets(['set05']) == ['split3'] + assert not storage.is_flag_set_exist('set04') + + def test_flag_sets_withut_config_sets(self): + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + assert storage.flag_set_filter.flag_sets == set({}) + assert not storage.flag_set_filter.should_filter + + assert storage.flag_set.sets_feature_flag_map == {} + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set10', 'set02']) + split2 = Split('split2', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set05', 'set02']) + split3 = Split('split3', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set04', 'set05']) + storage.update([split1], [], 1) + assert storage.get_feature_flags_by_sets(['set10']) == ['split1'] + assert storage.get_feature_flags_by_sets(['set02']) == ['split1'] + assert storage.is_flag_set_exist('set10') + assert storage.is_flag_set_exist('set02') + assert not storage.is_flag_set_exist('set03') + + storage.update([split2], [], 1) + assert storage.get_feature_flags_by_sets(['set05']) == ['split2'] + assert sorted(storage.get_feature_flags_by_sets(['set02', 'set05'])) == ['split1', 'split2'] + assert storage.is_flag_set_exist('set05') + + storage.update([], [split2.name], 1) + assert not storage.is_flag_set_exist('set05') + assert storage.get_feature_flags_by_sets(['set02']) == ['split1'] + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set02']) + storage.update([split1], [], 1) + assert not storage.is_flag_set_exist('set10') + assert storage.get_feature_flags_by_sets(['set02']) == ['split1'] + + storage.update([], [split1.name], 1) + assert storage.get_feature_flags_by_sets(['set02']) == [] + assert storage.flag_set.sets_feature_flag_map == {} + + storage.update([split3], [], 1) + assert storage.get_feature_flags_by_sets(['set05']) == ['split3'] + assert storage.get_feature_flags_by_sets(['set04', 'set05']) == ['split3'] + + def test_internal_event_notification(self, mocker): + """Test storing and retrieving splits works.""" + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + + split = mocker.Mock(spec=Split) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_split' + type(split).name = name_property + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split).sets = sets_property + + storage.update([split], [], -1) + assert storage.get('some_split') == split + assert storage.get_split_names() == ['some_split'] + assert storage.get_all_splits() == [split] + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.FLAGS_UPDATED + assert event.metadata.get_type() == SdkEventType.FLAG_UPDATE + assert event.metadata.get_names() == {'some_split'} + + split2 = Split('another_split', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1) + storage.update([split2], ['some_split'], 1) + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.FLAGS_UPDATED + assert event.metadata.get_type() == SdkEventType.FLAG_UPDATE + assert event.metadata.get_names() == {'another_split', 'some_split'} + + storage.kill_locally('another_split', 'default_treatment', 3) + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.FLAG_KILLED_NOTIFICATION + assert event.metadata.get_type() == SdkEventType.FLAG_UPDATE + assert event.metadata.get_names() == {'another_split'} + +class InMemorySplitStorageAsyncTests(object): + """In memory split storage test cases.""" + + @pytest.mark.asyncio + async def test_storing_retrieving_splits(self, mocker): + """Test storing and retrieving splits works.""" + storage = InMemorySplitStorageAsync(asyncio.Queue()) + + split = mocker.Mock(spec=Split) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_split' + type(split).name = name_property + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split).sets = sets_property + + await storage.update([split], [], -1) + assert await storage.get('some_split') == split + assert await storage.get_split_names() == ['some_split'] + assert await storage.get_all_splits() == [split] + assert await storage.get('nonexistant_split') is None + + await storage.update([], ['some_split'], -1) + assert await storage.get('some_split') is None + + @pytest.mark.asyncio + async def test_get_splits(self, mocker): + """Test retrieving a list of passed splits.""" + split1 = mocker.Mock() + name1_prop = mocker.PropertyMock() + name1_prop.return_value = 'split1' + type(split1).name = name1_prop + split2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'split2' + type(split2).name = name2_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property + + storage = InMemorySplitStorageAsync(asyncio.Queue()) + await storage.update([split1, split2], [], -1) + + splits = await storage.fetch_many(['split1', 'split2', 'split3']) + assert len(splits) == 3 + assert splits['split1'].name == 'split1' + assert splits['split2'].name == 'split2' + assert 'split3' in splits + + @pytest.mark.asyncio + async def test_store_get_changenumber(self): + """Test that storing and retrieving change numbers works.""" + storage = InMemorySplitStorageAsync(asyncio.Queue()) + assert await storage.get_change_number() == -1 + await storage.update([], [], 5) + assert await storage.get_change_number() == 5 + + @pytest.mark.asyncio + async def test_get_split_names(self, mocker): + """Test retrieving a list of all split names.""" + split1 = mocker.Mock() + name1_prop = mocker.PropertyMock() + name1_prop.return_value = 'split1' + type(split1).name = name1_prop + split2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'split2' + type(split2).name = name2_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property + + storage = InMemorySplitStorageAsync(asyncio.Queue()) + await storage.update([split1, split2], [], -1) + assert set(await storage.get_split_names()) == set(['split1', 'split2']) + + @pytest.mark.asyncio + async def test_get_all_splits(self, mocker): + """Test retrieving a list of all split names.""" + split1 = mocker.Mock() + name1_prop = mocker.PropertyMock() + name1_prop.return_value = 'split1' + type(split1).name = name1_prop + split2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'split2' + type(split2).name = name2_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property + + storage = InMemorySplitStorageAsync(asyncio.Queue()) + await storage.update([split1, split2], [], -1) + + all_splits = await storage.get_all_splits() + assert next(s for s in all_splits if s.name == 'split1') + assert next(s for s in all_splits if s.name == 'split2') + + @pytest.mark.asyncio + async def test_is_valid_traffic_type(self, mocker): + """Test that traffic type validation works properly.""" + split1 = mocker.Mock() + name1_prop = mocker.PropertyMock() + name1_prop.return_value = 'split1' + type(split1).name = name1_prop + split2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'split2' + type(split2).name = name2_prop + split3 = mocker.Mock() + tt_user = mocker.PropertyMock() + tt_user.return_value = 'user' + tt_account = mocker.PropertyMock() + tt_account.return_value = 'account' + name3_prop = mocker.PropertyMock() + name3_prop.return_value = 'split3' + type(split3).name = name3_prop + type(split1).traffic_type_name = tt_user + type(split2).traffic_type_name = tt_account + type(split3).traffic_type_name = tt_user + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property + type(split3).sets = sets_property + + storage = InMemorySplitStorageAsync(asyncio.Queue()) + + await storage.update([split1], [], -1) + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is False + + await storage.update([split2], [], -1) + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is True + + await storage.update([split3], [], -1) + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is True + + await storage.update([], ['split1'], -1) + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is True + + await storage.update([], ['split2'], -1) + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is False + + await storage.update([], ['split3'], -1) + assert await storage.is_valid_traffic_type('user') is False + assert await storage.is_valid_traffic_type('account') is False + + @pytest.mark.asyncio + async def test_traffic_type_inc_dec_logic(self, mocker): + """Test that adding/removing split, handles traffic types correctly.""" + storage = InMemorySplitStorageAsync(asyncio.Queue()) + + split1 = mocker.Mock() + name1_prop = mocker.PropertyMock() + name1_prop.return_value = 'split1' + type(split1).name = name1_prop + + split2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'split1' + type(split2).name = name2_prop + tt_user = mocker.PropertyMock() + tt_user.return_value = 'user' + tt_account = mocker.PropertyMock() + tt_account.return_value = 'account' + type(split1).traffic_type_name = tt_user + type(split2).traffic_type_name = tt_account + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property + + await storage.update([split1], [], -1) + assert await storage.is_valid_traffic_type('user') is True + assert await storage.is_valid_traffic_type('account') is False + + await storage.update([split2], [], -1) + assert await storage.is_valid_traffic_type('user') is False + assert await storage.is_valid_traffic_type('account') is True + + @pytest.mark.asyncio + async def test_kill_locally(self): + """Test kill local.""" + storage = InMemorySplitStorageAsync(asyncio.Queue()) + + split = Split('some_split', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1) + await storage.update([split], [], 1) + + await storage.kill_locally('test', 'default_treatment', 2) + assert await storage.get('test') is None + + await storage.kill_locally('some_split', 'default_treatment', 0) + split = await storage.get('some_split') + assert split.change_number == 1 + assert split.killed is False + assert split.default_treatment == 'some' + + await storage.kill_locally('some_split', 'default_treatment', 3) + split = await storage.get('some_split') + assert split.change_number == 3 + + @pytest.mark.asyncio + async def test_flag_sets_with_config_sets(self): + storage = InMemorySplitStorageAsync(asyncio.Queue(), ['set10', 'set02', 'set05']) + assert storage.flag_set_filter.flag_sets == {'set10', 'set02', 'set05'} + assert storage.flag_set_filter.should_filter + + assert storage.flag_set.sets_feature_flag_map == {'set10': set(), 'set02': set(), 'set05': set()} + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set10', 'set02']) + split2 = Split('split2', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set05', 'set02']) + split3 = Split('split3', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set04', 'set05']) + await storage.update([split1], [], 1) + assert await storage.get_feature_flags_by_sets(['set10']) == ['split1'] + assert await storage.get_feature_flags_by_sets(['set02']) == ['split1'] + assert await storage.get_feature_flags_by_sets(['set02', 'set10']) == ['split1'] + assert await storage.is_flag_set_exist('set10') + assert await storage.is_flag_set_exist('set02') + assert not await storage.is_flag_set_exist('set03') + + await storage.update([split2], [], 1) + assert await storage.get_feature_flags_by_sets(['set05']) == ['split2'] + assert sorted(await storage.get_feature_flags_by_sets(['set02', 'set05'])) == ['split1', 'split2'] + assert await storage.is_flag_set_exist('set05') + + await storage.update([], [split2.name], 1) + assert await storage.is_flag_set_exist('set05') + assert await storage.get_feature_flags_by_sets(['set02']) == ['split1'] + assert await storage.get_feature_flags_by_sets(['set05']) == [] + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set02']) + await storage.update([split1], [], 1) + assert await storage.is_flag_set_exist('set10') + assert await storage.get_feature_flags_by_sets(['set02']) == ['split1'] + + await storage.update([], [split1.name], 1) + assert await storage.get_feature_flags_by_sets(['set02']) == [] + assert storage.flag_set.sets_feature_flag_map == {'set10': set(), 'set02': set(), 'set05': set()} + + await storage.update([split3], [], 1) + assert await storage.get_feature_flags_by_sets(['set05']) == ['split3'] + assert not await storage.is_flag_set_exist('set04') + + @pytest.mark.asyncio + async def test_flag_sets_withut_config_sets(self): + storage = InMemorySplitStorageAsync(asyncio.Queue()) + assert storage.flag_set_filter.flag_sets == set({}) + assert not storage.flag_set_filter.should_filter + + assert storage.flag_set.sets_feature_flag_map == {} + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set10', 'set02']) + split2 = Split('split2', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set05', 'set02']) + split3 = Split('split3', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set04', 'set05']) + await storage.update([split1], [], 1) + assert await storage.get_feature_flags_by_sets(['set10']) == ['split1'] + assert await storage.get_feature_flags_by_sets(['set02']) == ['split1'] + assert await storage.is_flag_set_exist('set10') + assert await storage.is_flag_set_exist('set02') + assert not await storage.is_flag_set_exist('set03') + + await storage.update([split2], [], 1) + assert await storage.get_feature_flags_by_sets(['set05']) == ['split2'] + assert sorted(await storage.get_feature_flags_by_sets(['set02', 'set05'])) == ['split1', 'split2'] + assert await storage.is_flag_set_exist('set05') + + await storage.update([], [split2.name], 1) + assert not await storage.is_flag_set_exist('set05') + assert await storage.get_feature_flags_by_sets(['set02']) == ['split1'] + + split1 = Split('split1', 123456789, False, 'some', 'traffic_type', + 'ACTIVE', 1, sets=['set02']) + await storage.update([split1], [], 1) + assert not await storage.is_flag_set_exist('set10') + assert await storage.get_feature_flags_by_sets(['set02']) == ['split1'] + + await storage.update([], [split1.name], 1) + assert await storage.get_feature_flags_by_sets(['set02']) == [] + assert storage.flag_set.sets_feature_flag_map == {} + + await storage.update([split3], [], 1) + assert await storage.get_feature_flags_by_sets(['set05']) == ['split3'] + assert await storage.get_feature_flags_by_sets(['set04', 'set05']) == ['split3'] + + @pytest.mark.asyncio + async def test_internal_event_notification(self, mocker): + """Test retrieving a list of all split names.""" + split1 = mocker.Mock() + name1_prop = mocker.PropertyMock() + name1_prop.return_value = 'split1' + type(split1).name = name1_prop + split2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'split2' + type(split2).name = name2_prop + sets_property = mocker.PropertyMock() + sets_property.return_value = ['set_1'] + type(split1).sets = sets_property + type(split2).sets = sets_property + events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(events_queue) + await storage.update([split1, split2], [], -1) + event = await events_queue.get() + assert event.internal_event == SdkInternalEvent.FLAGS_UPDATED + assert event.metadata.get_type() == SdkEventType.FLAG_UPDATE + assert event.metadata.get_names() == {'split1', 'split2'} + + await storage.kill_locally('split1', 'default_treatment', 3) + event = await events_queue.get() + assert event.internal_event == SdkInternalEvent.FLAG_KILLED_NOTIFICATION + assert event.metadata.get_type() == SdkEventType.FLAG_UPDATE + assert event.metadata.get_names() == {'split1'} + class InMemorySegmentStorageTests(object): """In memory segment storage tests.""" def test_segment_storage_retrieval(self, mocker): """Test storing and retrieving segments.""" - storage = InMemorySegmentStorage() + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) segment = mocker.Mock(spec=Segment) name_property = mocker.PropertyMock() name_property.return_value = 'some_segment' @@ -211,14 +756,16 @@ def test_segment_storage_retrieval(self, mocker): def test_change_number(self, mocker): """Test storing and retrieving segment changeNumber.""" - storage = InMemorySegmentStorage() + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) storage.set_change_number('some_segment', 123) # Change number is not updated if segment doesn't exist assert storage.get_change_number('some_segment') is None assert storage.get_change_number('nonexistant-segment') is None # Change number is updated if segment does exist. - storage = InMemorySegmentStorage() + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) segment = mocker.Mock(spec=Segment) name_property = mocker.PropertyMock() name_property.return_value = 'some_segment' @@ -229,7 +776,8 @@ def test_change_number(self, mocker): def test_segment_contains(self, mocker): """Test using storage to determine whether a key belongs to a segment.""" - storage = InMemorySegmentStorage() + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) segment = mocker.Mock(spec=Segment) name_property = mocker.PropertyMock() name_property.return_value = 'some_segment' @@ -241,7 +789,8 @@ def test_segment_contains(self, mocker): def test_segment_update(self): """Test updating a segment.""" - storage = InMemorySegmentStorage() + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) segment = Segment('some_segment', ['key1', 'key2', 'key3'], 123) storage.put(segment) assert storage.get('some_segment') == segment @@ -254,75 +803,280 @@ def test_segment_update(self): assert not storage.segment_contains('some_segment', 'key3') assert storage.get_change_number('some_segment') == 456 + def test_internal_event_notification(self): + """Test updating a segment.""" + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) + segment = Segment('some_segment', ['key1', 'key2', 'key3'], 123) + storage.put(segment) + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + + storage.update('some_segment', ['key4', 'key5'], ['key2', 'key3'], 456) + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + +class InMemorySegmentStorageAsyncTests(object): + """In memory segment storage tests.""" + + @pytest.mark.asyncio + async def test_segment_storage_retrieval(self, mocker): + """Test storing and retrieving segments.""" + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + segment = mocker.Mock(spec=Segment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment).name = name_property + + await storage.put(segment) + assert await storage.get('some_segment') == segment + assert await storage.get('nonexistant-segment') is None + + @pytest.mark.asyncio + async def test_change_number(self, mocker): + """Test storing and retrieving segment changeNumber.""" + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + await storage.set_change_number('some_segment', 123) + # Change number is not updated if segment doesn't exist + assert await storage.get_change_number('some_segment') is None + assert await storage.get_change_number('nonexistant-segment') is None + + # Change number is updated if segment does exist. + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + segment = mocker.Mock(spec=Segment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment).name = name_property + await storage.put(segment) + await storage.set_change_number('some_segment', 123) + assert await storage.get_change_number('some_segment') == 123 + + @pytest.mark.asyncio + async def test_segment_contains(self, mocker): + """Test using storage to determine whether a key belongs to a segment.""" + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + segment = mocker.Mock(spec=Segment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment).name = name_property + await storage.put(segment) + + await storage.segment_contains('some_segment', 'abc') + assert segment.contains.mock_calls[0] == mocker.call('abc') + + @pytest.mark.asyncio + async def test_segment_update(self): + """Test updating a segment.""" + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + segment = Segment('some_segment', ['key1', 'key2', 'key3'], 123) + await storage.put(segment) + assert await storage.get('some_segment') == segment + + await storage.update('some_segment', ['key4', 'key5'], ['key2', 'key3'], 456) + assert await storage.segment_contains('some_segment', 'key1') + assert await storage.segment_contains('some_segment', 'key4') + assert await storage.segment_contains('some_segment', 'key5') + assert not await storage.segment_contains('some_segment', 'key2') + assert not await storage.segment_contains('some_segment', 'key3') + assert await storage.get_change_number('some_segment') == 456 + + @pytest.mark.asyncio + async def test_internal_event_notification(self): + """Test updating a segment.""" + events_queue = asyncio.Queue() + storage = InMemorySegmentStorageAsync(events_queue) + segment = Segment('some_segment', ['key1', 'key2', 'key3'], 123) + await storage.put(segment) + event = await events_queue.get() + assert event.internal_event == SdkInternalEvent.SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + + await storage.update('some_segment', ['key4', 'key5'], ['key2', 'key3'], 456) + event = await events_queue.get() + assert event.internal_event == SdkInternalEvent.SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 class InMemoryImpressionsStorageTests(object): """InMemory impressions storage test cases.""" - def test_push_pop_impressions(self): + def test_push_pop_impressions(self, mocker): """Test pushing and retrieving impressions.""" - storage = InMemoryImpressionStorage(100) - storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) - storage.put([Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) - storage.put([Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryImpressionStorage(100, telemetry_runtime_producer) + storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + storage.put([Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + storage.put([Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + assert(telemetry_storage._counters._impressions_queued == 3) # Assert impressions are retrieved in the same order they are inserted. assert storage.pop_many(1) == [ - Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] assert storage.pop_many(1) == [ - Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] assert storage.pop_many(1) == [ - Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] # Assert inserting multiple impressions at once works and maintains order. impressions = [ - Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654), - Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654), - Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] assert storage.put(impressions) # Assert impressions are retrieved in the same order they are inserted. assert storage.pop_many(1) == [ - Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] assert storage.pop_many(1) == [ - Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] assert storage.pop_many(1) == [ - Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) ] def test_queue_full_hook(self, mocker): """Test queue_full_hook is executed when the queue is full.""" - storage = InMemoryImpressionStorage(100) + storage = InMemoryImpressionStorage(100, mocker.Mock()) queue_full_hook = mocker.Mock() storage.set_queue_full_hook(queue_full_hook) impressions = [ - Impression('key%d' % i, 'feature1', 'on', 'l1', 123456, 'b1', 321654) + Impression('key%d' % i, 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) for i in range(0, 101) ] storage.put(impressions) assert queue_full_hook.mock_calls == mocker.call() - def test_clear(self): + def test_clear(self, mocker): """Test clear method.""" - storage = InMemoryImpressionStorage(100) - storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) + storage = InMemoryImpressionStorage(100, mocker.Mock()) + storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) assert storage._impressions.qsize() == 1 storage.clear() assert storage._impressions.qsize() == 0 + def test_impressions_dropped(self, mocker): + """Test pushing and retrieving impressions.""" + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryImpressionStorage(2, telemetry_runtime_producer) + storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + assert(telemetry_storage._counters._impressions_dropped == 1) + assert(telemetry_storage._counters._impressions_queued == 2) + + +class InMemoryImpressionsStorageAsyncTests(object): + """InMemory impressions async storage test cases.""" + + @pytest.mark.asyncio + async def test_push_pop_impressions(self, mocker): + """Test pushing and retrieving impressions.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryImpressionStorageAsync(100, telemetry_runtime_producer) + await storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + await storage.put([Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + await storage.put([Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + assert(telemetry_storage._counters._impressions_queued == 3) + + # Assert impressions are retrieved in the same order they are inserted. + assert await storage.pop_many(1) == [ + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + assert await storage.pop_many(1) == [ + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + assert await storage.pop_many(1) == [ + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + + # Assert inserting multiple impressions at once works and maintains order. + impressions = [ + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + assert await storage.put(impressions) + + # Assert impressions are retrieved in the same order they are inserted. + assert await storage.pop_many(1) == [ + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + assert await storage.pop_many(1) == [ + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + assert await storage.pop_many(1) == [ + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + ] + + @pytest.mark.asyncio + async def test_queue_full_hook(self, mocker): + """Test queue_full_hook is executed when the queue is full.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryImpressionStorageAsync(100, telemetry_runtime_producer) + self.hook_called = False + async def queue_full_hook(): + self.hook_called = True + + storage.set_queue_full_hook(queue_full_hook) + impressions = [ + Impression('key%d' % i, 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None) + for i in range(0, 101) + ] + await storage.put(impressions) + await queue_full_hook() + assert self.hook_called == True + + @pytest.mark.asyncio + async def test_clear(self, mocker): + """Test clear method.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryImpressionStorageAsync(100, telemetry_runtime_producer) + await storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + assert storage._impressions.qsize() == 1 + await storage.clear() + assert storage._impressions.qsize() == 0 + + @pytest.mark.asyncio + async def test_impressions_dropped(self, mocker): + """Test pushing and retrieving impressions.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryImpressionStorageAsync(2, telemetry_runtime_producer) + await storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + await storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + await storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654, None, None)]) + assert(telemetry_storage._counters._impressions_dropped == 1) + assert(telemetry_storage._counters._impressions_queued == 2) + class InMemoryEventsStorageTests(object): """InMemory events storage test cases.""" - def test_push_pop_events(self): + def test_push_pop_events(self, mocker): """Test pushing and retrieving events.""" - storage = InMemoryEventStorage(100) + storage = InMemoryEventStorage(100, mocker.Mock()) storage.put([EventWrapper( event=Event('key1', 'user', 'purchase', 3.5, 123456, None), size=1024, @@ -365,7 +1119,7 @@ def test_push_pop_events(self): def test_queue_full_hook(self, mocker): """Test queue_full_hook is executed when the queue is full.""" - storage = InMemoryEventStorage(100) + storage = InMemoryEventStorage(100, mocker.Mock()) queue_full_hook = mocker.Mock() storage.set_queue_full_hook(queue_full_hook) events = [EventWrapper(event=Event('key%d' % i, 'user', 'purchase', 12.5, 321654, None), size=1024) for i in range(0, 101)] @@ -374,16 +1128,16 @@ def test_queue_full_hook(self, mocker): def test_queue_full_hook_properties(self, mocker): """Test queue_full_hook is executed when the queue is full regarding properties.""" - storage = InMemoryEventStorage(200) + storage = InMemoryEventStorage(200, mocker.Mock()) queue_full_hook = mocker.Mock() storage.set_queue_full_hook(queue_full_hook) events = [EventWrapper(event=Event('key%d' % i, 'user', 'purchase', 12.5, 1, None), size=32768) for i in range(160)] storage.put(events) assert queue_full_hook.mock_calls == [mocker.call()] - def test_clear(self): + def test_clear(self, mocker): """Test clear method.""" - storage = InMemoryEventStorage(100) + storage = InMemoryEventStorage(100, mocker.Mock()) storage.put([EventWrapper( event=Event('key1', 'user', 'purchase', 3.5, 123456, None), size=1024, @@ -392,3 +1146,959 @@ def test_clear(self): assert storage._events.qsize() == 1 storage.clear() assert storage._events.qsize() == 0 + + def test_event_telemetry(self, mocker): + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryEventStorage(2, telemetry_runtime_producer) + storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + assert(telemetry_storage._counters._events_dropped == 1) + assert(telemetry_storage._counters._events_queued == 2) + + +class InMemoryEventsStorageAsyncTests(object): + """InMemory events async storage test cases.""" + + @pytest.mark.asyncio + async def test_push_pop_events(self, mocker): + """Test pushing and retrieving events.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryEventStorageAsync(100, telemetry_runtime_producer) + await storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + await storage.put([EventWrapper( + event=Event('key2', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + await storage.put([EventWrapper( + event=Event('key3', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + + # Assert impressions are retrieved in the same order they are inserted. + assert await storage.pop_many(1) == [Event('key1', 'user', 'purchase', 3.5, 123456, None)] + assert await storage.pop_many(1) == [Event('key2', 'user', 'purchase', 3.5, 123456, None)] + assert await storage.pop_many(1) == [Event('key3', 'user', 'purchase', 3.5, 123456, None)] + + # Assert inserting multiple impressions at once works and maintains order. + events = [ + EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + EventWrapper( + event=Event('key2', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + EventWrapper( + event=Event('key3', 'user', 'purchase', 3.5, 123456, None), + size=1024, + ), + ] + assert await storage.put(events) + + # Assert events are retrieved in the same order they are inserted. + assert await storage.pop_many(1) == [Event('key1', 'user', 'purchase', 3.5, 123456, None)] + assert await storage.pop_many(1) == [Event('key2', 'user', 'purchase', 3.5, 123456, None)] + assert await storage.pop_many(1) == [Event('key3', 'user', 'purchase', 3.5, 123456, None)] + + @pytest.mark.asyncio + async def test_queue_full_hook(self, mocker): + """Test queue_full_hook is executed when the queue is full.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryEventStorageAsync(100, telemetry_runtime_producer) + self.called = False + async def queue_full_hook(): + self.called = True + + storage.set_queue_full_hook(queue_full_hook) + events = [EventWrapper(event=Event('key%d' % i, 'user', 'purchase', 12.5, 321654, None), size=1024) for i in range(0, 101)] + await storage.put(events) + assert self.called == True + + @pytest.mark.asyncio + async def test_queue_full_hook_properties(self, mocker): + """Test queue_full_hook is executed when the queue is full regarding properties.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryEventStorageAsync(200, telemetry_runtime_producer) + self.called = False + async def queue_full_hook(): + self.called = True + storage.set_queue_full_hook(queue_full_hook) + events = [EventWrapper(event=Event('key%d' % i, 'user', 'purchase', 12.5, 1, None), size=32768) for i in range(160)] + await storage.put(events) + assert self.called == True + + @pytest.mark.asyncio + async def test_clear(self, mocker): + """Test clear method.""" + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryEventStorageAsync(100, telemetry_runtime_producer) + await storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + + assert storage._events.qsize() == 1 + await storage.clear() + assert storage._events.qsize() == 0 + + @pytest.mark.asyncio + async def test_event_telemetry(self, mocker): + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + storage = InMemoryEventStorageAsync(2, telemetry_runtime_producer) + await storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + await storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + await storage.put([EventWrapper( + event=Event('key1', 'user', 'purchase', 3.5, 123456, None), + size=1024, + )]) + assert(telemetry_storage._counters._events_dropped == 1) + assert(telemetry_storage._counters._events_queued == 2) + + +class InMemoryTelemetryStorageTests(object): + """InMemory telemetry storage test cases.""" + + def test_resets(self): + storage = InMemoryTelemetryStorage() + + assert(storage._counters._impressions_queued == 0) + assert(storage._counters._impressions_deduped == 0) + assert(storage._counters._impressions_dropped == 0) + assert(storage._counters._events_dropped == 0) + assert(storage._counters._events_queued == 0) + assert(storage._counters._auth_rejections == 0) + assert(storage._counters._token_refreshes == 0) + + assert(storage._method_exceptions.pop_all() == {'methodExceptions': {'treatment': 0, 'treatments': 0, 'treatment_with_config': 0, 'treatments_with_config': 0, 'treatments_by_flag_set': 0, 'treatments_by_flag_sets': 0, 'treatments_with_config_by_flag_set': 0, 'treatments_with_config_by_flag_sets': 0, 'track': 0}}) + assert(storage._last_synchronization.get_all() == {'lastSynchronizations': {'split': 0, 'segment': 0, 'impression': 0, 'impressionCount': 0, 'event': 0, 'telemetry': 0, 'token': 0}}) + assert(storage._http_sync_errors.pop_all() == {'httpErrors': {'split': {}, 'segment': {}, 'impression': {}, 'impressionCount': {}, 'event': {}, 'telemetry': {}, 'token': {}}}) + assert(storage._tel_config.get_stats() == { + 'bT':0, + 'nR':0, + 'tR': 0, + 'oM': None, + 'sT': None, + 'sE': None, + 'rR': {'sp': 0, 'se': 0, 'im': 0, 'ev': 0, 'te': 0}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': 0, + 'eQ': 0, + 'iM': None, + 'iL': False, + 'hp': None, + 'aF': 0, + 'rF': 0, + 'fsT': 0, + 'fsI': 0 + }) + assert(storage._streaming_events.pop_streaming_events() == {'streamingEvents': []}) + assert(storage._tags == []) + + assert(storage._method_latencies.pop_all() == {'methodLatencies': {'treatment': [0] * 23, 'treatments': [0] * 23, 'treatment_with_config': [0] * 23, 'treatments_with_config': [0] * 23, 'treatments_by_flag_set': [0] * 23, 'treatments_by_flag_sets': [0] * 23, 'treatments_with_config_by_flag_set': [0] * 23, 'treatments_with_config_by_flag_sets': [0] * 23, 'track': [0] * 23}}) + assert(storage._http_latencies.pop_all() == {'httpLatencies': {'split': [0] * 23, 'segment': [0] * 23, 'impression': [0] * 23, 'impressionCount': [0] * 23, 'event': [0] * 23, 'telemetry': [0] * 23, 'token': [0] * 23}}) + + def test_record_config(self): + storage = InMemoryTelemetryStorage() + config = {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG','' + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None + } + storage.record_config(config, {}, 2, 1) + storage.record_active_and_redundant_factories(1, 0) + assert(storage._tel_config.get_stats() == {'oM': 0, + 'sT': storage._tel_config._get_storage_type(config['operationMode'], config['storageType']), + 'sE': config['streamingEnabled'], + 'rR': {'sp': 30, 'se': 30, 'im': 60, 'ev': 60, 'te': 10}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': config['impressionsQueueSize'], + 'eQ': config['eventsQueueSize'], + 'iM': storage._tel_config._get_impressions_mode(config['impressionsMode']), + 'iL': True if config['impressionListener'] is not None else False, + 'hp': storage._tel_config._check_if_proxy_detected(), + 'bT': 0, + 'tR': 0, + 'nR': 0, + 'aF': 1, + 'rF': 0, + 'fsT': 2, + 'fsI': 1} + ) + + def test_record_counters(self): + storage = InMemoryTelemetryStorage() + + storage.record_ready_time(10) + assert(storage._tel_config._time_until_ready == 10) + + storage.add_tag('tag') + assert('tag' in storage._tags) + [storage.add_tag('tag') for i in range(1, 25)] + assert(len(storage._tags) == 10) + + storage.record_bur_time_out() + storage.record_bur_time_out() + assert(storage._tel_config.get_bur_time_outs() == 2) + + storage.record_not_ready_usage() + storage.record_not_ready_usage() + assert(storage._tel_config.get_non_ready_usage() == 2) + + storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT) + assert(storage._method_exceptions._treatment == 1) + + storage.record_impression_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED, 5) + assert(storage._counters.get_counter_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED) == 5) + + storage.record_event_stats(ModelTelemetry.CounterConstants.EVENTS_DROPPED, 6) + assert(storage._counters.get_counter_stats(ModelTelemetry.CounterConstants.EVENTS_DROPPED) == 6) + + storage.record_successful_sync(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, 10) + assert(storage._last_synchronization._segment == 10) + + storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, '500') + assert(storage._http_sync_errors._segment['500'] == 1) + + storage.record_auth_rejections() + storage.record_auth_rejections() + assert(storage._counters.pop_auth_rejections() == 2) + + storage.record_token_refreshes() + storage.record_token_refreshes() + assert(storage._counters.pop_token_refreshes() == 2) + + storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + assert(storage._streaming_events.pop_streaming_events() == {'streamingEvents': [{'e': ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value, 'd': 'split', 't': 1234}]}) + [storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) for i in range(1, 25)] + assert(len(storage._streaming_events._streaming_events) == 20) + + storage.record_session_length(20) + assert(storage._counters.get_session_length() == 20) + + def test_record_latencies(self): + storage = InMemoryTelemetryStorage() + + for method in ModelTelemetry.MethodExceptionsAndLatencies: + if self._get_method_latency(method, storage) == None: + continue + storage.record_latency(method, 50) + assert(self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(50)] == 1) + storage.record_latency(method, 50000000) + assert(self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + for j in range(10): + latency = random.randint(1001, 4987885) + current_count = self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(latency)] + [storage.record_latency(method, latency) for i in range(2)] + assert(self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(latency)] == 2 + current_count) + + for resource in ModelTelemetry.HTTPExceptionsAndLatencies: + if self._get_http_latency(resource, storage) == None: + continue + storage.record_sync_latency(resource, 50) + assert(self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(50)] == 1) + storage.record_sync_latency(resource, 50000000) + assert(self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + for j in range(10): + latency = random.randint(1001, 4987885) + current_count = self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(latency)] + [storage.record_sync_latency(resource, latency) for i in range(2)] + assert(self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(latency)] == 2 + current_count) + + def _get_method_latency(self, resource, storage): + if resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT: + return storage._method_latencies._treatment + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS: + return storage._method_latencies._treatments + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG: + return storage._method_latencies._treatment_with_config + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG: + return storage._method_latencies._treatments_with_config + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET: + return storage._method_latencies._treatments_by_flag_set + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS: + return storage._method_latencies._treatments_by_flag_sets + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET: + return storage._method_latencies._treatments_with_config_by_flag_set + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS: + return storage._method_latencies._treatments_with_config_by_flag_sets + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TRACK: + return storage._method_latencies._track + else: + return + + def _get_http_latency(self, resource, storage): + if resource == ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT: + return storage._http_latencies._split + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT: + return storage._http_latencies._segment + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION: + return storage._http_latencies._impression + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + return storage._http_latencies._impression_count + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.EVENT: + return storage._http_latencies._event + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY: + return storage._http_latencies._telemetry + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN: + return storage._http_latencies._token + else: + return + + def test_pop_counters(self): + storage = InMemoryTelemetryStorage() + + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT) for i in range(2)] + storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS) + storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG) + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG) for i in range(5)] + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET) for i in range(3)] + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS) for i in range(10)] + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET) for i in range(7)] + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS) for i in range(6)] + [storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TRACK) for i in range(3)] + exceptions = storage.pop_exceptions() + assert(storage._method_exceptions._treatment == 0) + assert(storage._method_exceptions._treatments == 0) + assert(storage._method_exceptions._treatment_with_config == 0) + assert(storage._method_exceptions._treatments_with_config == 0) + assert(storage._method_exceptions._treatments_by_flag_set == 0) + assert(storage._method_exceptions._treatments_by_flag_sets == 0) + assert(storage._method_exceptions._track == 0) + assert(storage._method_exceptions._treatments_with_config_by_flag_set == 0) + assert(storage._method_exceptions._treatments_with_config_by_flag_sets == 0) + assert(exceptions == {'methodExceptions': {'treatment': 2, 'treatments': 1, 'treatment_with_config': 1, 'treatments_with_config': 5, 'treatments_by_flag_set': 3, 'treatments_by_flag_sets': 10, 'treatments_with_config_by_flag_set': 7, 'treatments_with_config_by_flag_sets': 6, 'track': 3}}) + + storage.add_tag('tag1') + storage.add_tag('tag2') + tags = storage.pop_tags() + assert(storage._tags == []) + assert(tags == ['tag1', 'tag2']) + + [storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, str(i)) for i in [500, 501, 502]] + [storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, str(i)) for i in [400, 401, 402]] + storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, '502') + [storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, str(i)) for i in [501, 502]] + storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, '501') + storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, '505') + [storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, '502') for i in range(5)] + http_errors = storage.pop_http_errors() + assert(http_errors == {'httpErrors': {'split': {'400': 1, '401': 1, '402': 1}, 'segment': {'500': 1, '501': 1, '502': 1}, + 'impression': {'502': 1}, 'impressionCount': {'501': 1, '502': 1}, + 'event': {'501': 1}, 'telemetry': {'505': 1}, 'token': {'502': 5}}}) + assert(storage._http_sync_errors._split == {}) + assert(storage._http_sync_errors._segment == {}) + assert(storage._http_sync_errors._impression == {}) + assert(storage._http_sync_errors._impression_count == {}) + assert(storage._http_sync_errors._event == {}) + assert(storage._http_sync_errors._telemetry == {}) + + storage.record_auth_rejections() + storage.record_auth_rejections() + auth_rejections = storage.pop_auth_rejections() + assert(storage._counters._auth_rejections == 0) + assert(auth_rejections == 2) + + storage.record_token_refreshes() + storage.record_token_refreshes() + token_refreshes = storage.pop_token_refreshes() + assert(storage._counters._token_refreshes == 0) + assert(token_refreshes == 2) + + storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.OCCUPANCY_PRI, 'split', 1234)) + streaming_events = storage.pop_streaming_events() + assert(storage._streaming_events._streaming_events == []) + assert(streaming_events == {'streamingEvents': [{'e': ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value, 'd': 'split', 't': 1234}, + {'e': ModelTelemetry.StreamingEventTypes.OCCUPANCY_PRI.value, 'd': 'split', 't': 1234}]}) + + def test_pop_latencies(self): + storage = InMemoryTelemetryStorage() + + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT, i) for i in [5, 10, 10, 10]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS, i) for i in [7, 10, 14, 13]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG, i) for i in [200]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG, i) for i in [50, 40]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, i) for i in [15, 20]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, i) for i in [14, 25]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, i) for i in [100]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, i) for i in [50, 20]] + [storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TRACK, i) for i in [1, 10, 100]] + latencies = storage.pop_latencies() + + assert(storage._method_latencies._treatment == [0] * 23) + assert(storage._method_latencies._treatments == [0] * 23) + assert(storage._method_latencies._treatment_with_config == [0] * 23) + assert(storage._method_latencies._treatments_with_config == [0] * 23) + assert(storage._method_latencies._treatments_by_flag_set == [0] * 23) + assert(storage._method_latencies._treatments_by_flag_sets == [0] * 23) + assert(storage._method_latencies._treatments_with_config_by_flag_set == [0] * 23) + assert(storage._method_latencies._treatments_with_config_by_flag_sets == [0] * 23) + assert(storage._method_latencies._track == [0] * 23) + assert(latencies == {'methodLatencies': { + 'treatment': [4] + [0] * 22, + 'treatments': [4] + [0] * 22, + 'treatment_with_config': [1] + [0] * 22, + 'treatments_with_config': [2] + [0] * 22, + 'treatments_by_flag_set': [2] + [0] * 22, + 'treatments_by_flag_sets': [2] + [0] * 22, + 'treatments_with_config_by_flag_set': [1] + [0] * 22, + 'treatments_with_config_by_flag_sets': [2] + [0] * 22, + 'track': [3] + [0] * 22}}) + + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, i) for i in [50, 10, 20, 40]] + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, i) for i in [70, 100, 40, 30]] + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, i) for i in [10, 20]] + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, i) for i in [5, 10]] + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, i) for i in [50, 40]] + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, i) for i in [100, 50, 160]] + [storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, i) for i in [10, 15, 100]] + sync_latency = storage.pop_http_latencies() + + assert(storage._http_latencies._split == [0] * 23) + assert(storage._http_latencies._segment == [0] * 23) + assert(storage._http_latencies._impression == [0] * 23) + assert(storage._http_latencies._impression_count == [0] * 23) + assert(storage._http_latencies._telemetry == [0] * 23) + assert(storage._http_latencies._token == [0] * 23) + assert(sync_latency == {'httpLatencies': {'split': [4] + [0] * 22, 'segment': [4] + [0] * 22, + 'impression': [2] + [0] * 22, 'impressionCount': [2] + [0] * 22, 'event': [2] + [0] * 22, + 'telemetry': [3] + [0] * 22, 'token': [3] + [0] * 22}}) + + +class InMemoryTelemetryStorageAsyncTests(object): + """InMemory telemetry async storage test cases.""" + + @pytest.mark.asyncio + async def test_resets(self): + storage = await InMemoryTelemetryStorageAsync.create() + + assert(storage._counters._impressions_queued == 0) + assert(storage._counters._impressions_deduped == 0) + assert(storage._counters._impressions_dropped == 0) + assert(storage._counters._events_dropped == 0) + assert(storage._counters._events_queued == 0) + assert(storage._counters._auth_rejections == 0) + assert(storage._counters._token_refreshes == 0) + + assert(await storage._method_exceptions.pop_all() == {'methodExceptions': {'treatment': 0, 'treatments': 0, 'treatment_with_config': 0, 'treatments_with_config': 0, 'treatments_by_flag_set': 0, 'treatments_by_flag_sets': 0, 'treatments_with_config_by_flag_set': 0, 'treatments_with_config_by_flag_sets': 0, 'track': 0}}) + assert(await storage._last_synchronization.get_all() == {'lastSynchronizations': {'split': 0, 'segment': 0, 'impression': 0, 'impressionCount': 0, 'event': 0, 'telemetry': 0, 'token': 0}}) + assert(await storage._http_sync_errors.pop_all() == {'httpErrors': {'split': {}, 'segment': {}, 'impression': {}, 'impressionCount': {}, 'event': {}, 'telemetry': {}, 'token': {}}}) + assert(await storage._tel_config.get_stats() == { + 'bT':0, + 'nR':0, + 'tR': 0, + 'oM': None, + 'sT': None, + 'sE': None, + 'rR': {'sp': 0, 'se': 0, 'im': 0, 'ev': 0, 'te': 0}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': 0, + 'eQ': 0, + 'iM': None, + 'iL': False, + 'hp': None, + 'aF': 0, + 'rF': 0, + 'fsT': 0, + 'fsI': 0 + }) + assert(await storage._streaming_events.pop_streaming_events() == {'streamingEvents': []}) + assert(storage._tags == []) + + assert(await storage._method_latencies.pop_all() == {'methodLatencies': {'treatment': [0] * 23, 'treatments': [0] * 23, 'treatment_with_config': [0] * 23, 'treatments_with_config': [0] * 23, 'treatments_by_flag_set': [0] * 23, 'treatments_by_flag_sets': [0] * 23, 'treatments_with_config_by_flag_set': [0] * 23, 'treatments_with_config_by_flag_sets': [0] * 23, 'track': [0] * 23}}) + assert(await storage._http_latencies.pop_all() == {'httpLatencies': {'split': [0] * 23, 'segment': [0] * 23, 'impression': [0] * 23, 'impressionCount': [0] * 23, 'event': [0] * 23, 'telemetry': [0] * 23, 'token': [0] * 23}}) + + @pytest.mark.asyncio + async def test_record_config(self): + storage = await InMemoryTelemetryStorageAsync.create() + config = {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG','' + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None + } + await storage.record_config(config, {}, 2, 1) + await storage.record_active_and_redundant_factories(1, 0) + assert(await storage._tel_config.get_stats() == {'oM': 0, + 'sT': storage._tel_config._get_storage_type(config['operationMode'], config['storageType']), + 'sE': config['streamingEnabled'], + 'rR': {'sp': 30, 'se': 30, 'im': 60, 'ev': 60, 'te': 10}, + 'uO': {'s': False, 'e': False, 'a': False, 'st': False, 't': False}, + 'iQ': config['impressionsQueueSize'], + 'eQ': config['eventsQueueSize'], + 'iM': storage._tel_config._get_impressions_mode(config['impressionsMode']), + 'iL': True if config['impressionListener'] is not None else False, + 'hp': storage._tel_config._check_if_proxy_detected(), + 'bT': 0, + 'tR': 0, + 'nR': 0, + 'aF': 1, + 'rF': 0, + 'fsT': 2, + 'fsI': 1} + ) + + @pytest.mark.asyncio + async def test_record_counters(self): + storage = await InMemoryTelemetryStorageAsync.create() + + await storage.record_ready_time(10) + assert(storage._tel_config._time_until_ready == 10) + + await storage.add_tag('tag') + assert('tag' in storage._tags) + [await storage.add_tag('tag') for i in range(1, 25)] + assert(len(storage._tags) == 10) + + await storage.record_bur_time_out() + await storage.record_bur_time_out() + assert(await storage._tel_config.get_bur_time_outs() == 2) + + await storage.record_not_ready_usage() + await storage.record_not_ready_usage() + assert(await storage._tel_config.get_non_ready_usage() == 2) + + await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT) + assert(storage._method_exceptions._treatment == 1) + + await storage.record_impression_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED, 5) + assert(await storage._counters.get_counter_stats(ModelTelemetry.CounterConstants.IMPRESSIONS_QUEUED) == 5) + + await storage.record_event_stats(ModelTelemetry.CounterConstants.EVENTS_DROPPED, 6) + assert(await storage._counters.get_counter_stats(ModelTelemetry.CounterConstants.EVENTS_DROPPED) == 6) + + await storage.record_successful_sync(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, 10) + assert(storage._last_synchronization._segment == 10) + + await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, '500') + assert(storage._http_sync_errors._segment['500'] == 1) + + await storage.record_auth_rejections() + await storage.record_auth_rejections() + assert(await storage._counters.pop_auth_rejections() == 2) + + await storage.record_token_refreshes() + await storage.record_token_refreshes() + assert(await storage._counters.pop_token_refreshes() == 2) + + await storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + assert(await storage._streaming_events.pop_streaming_events() == {'streamingEvents': [{'e': ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value, 'd': 'split', 't': 1234}]}) + [await storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) for i in range(1, 25)] + assert(len(storage._streaming_events._streaming_events) == 20) + + await storage.record_session_length(20) + assert(await storage._counters.get_session_length() == 20) + + @pytest.mark.asyncio + async def test_record_latencies(self): + storage = await InMemoryTelemetryStorageAsync.create() + + for method in ModelTelemetry.MethodExceptionsAndLatencies: + if self._get_method_latency(method, storage) == None: + continue + await storage.record_latency(method, 50) + assert(self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(50)] == 1) + await storage.record_latency(method, 50000000) + assert(self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + for j in range(10): + latency = random.randint(1001, 4987885) + current_count = self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(latency)] + [await storage.record_latency(method, latency) for i in range(2)] + assert(self._get_method_latency(method, storage)[ModelTelemetry.get_latency_bucket_index(latency)] == 2 + current_count) + + for resource in ModelTelemetry.HTTPExceptionsAndLatencies: + if self._get_http_latency(resource, storage) == None: + continue + await storage.record_sync_latency(resource, 50) + assert(self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(50)] == 1) + await storage.record_sync_latency(resource, 50000000) + assert(self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(50000000)] == 1) + for j in range(10): + latency = random.randint(1001, 4987885) + current_count = self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(latency)] + [await storage.record_sync_latency(resource, latency) for i in range(2)] + assert(self._get_http_latency(resource, storage)[ModelTelemetry.get_latency_bucket_index(latency)] == 2 + current_count) + + def _get_method_latency(self, resource, storage): + if resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT: + return storage._method_latencies._treatment + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS: + return storage._method_latencies._treatments + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG: + return storage._method_latencies._treatment_with_config + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG: + return storage._method_latencies._treatments_with_config + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET: + return storage._method_latencies._treatments_by_flag_set + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS: + return storage._method_latencies._treatments_by_flag_sets + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET: + return storage._method_latencies._treatments_with_config_by_flag_set + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS: + return storage._method_latencies._treatments_with_config_by_flag_sets + elif resource == ModelTelemetry.MethodExceptionsAndLatencies.TRACK: + return storage._method_latencies._track + else: + return + + def _get_http_latency(self, resource, storage): + if resource == ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT: + return storage._http_latencies._split + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT: + return storage._http_latencies._segment + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION: + return storage._http_latencies._impression + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT: + return storage._http_latencies._impression_count + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.EVENT: + return storage._http_latencies._event + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY: + return storage._http_latencies._telemetry + elif resource == ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN: + return storage._http_latencies._token + else: + return + + @pytest.mark.asyncio + async def test_pop_counters(self): + storage = await InMemoryTelemetryStorageAsync.create() + + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT) for i in range(2)] + await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS) + await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG) + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG) for i in range(5)] + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET) for i in range(3)] + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS) for i in range(10)] + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET) for i in range(7)] + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS) for i in range(6)] + [await storage.record_exception(ModelTelemetry.MethodExceptionsAndLatencies.TRACK) for i in range(3)] + exceptions = await storage.pop_exceptions() + assert(storage._method_exceptions._treatment == 0) + assert(storage._method_exceptions._treatments == 0) + assert(storage._method_exceptions._treatment_with_config == 0) + assert(storage._method_exceptions._treatments_with_config == 0) + assert(storage._method_exceptions._treatments_by_flag_set == 0) + assert(storage._method_exceptions._treatments_by_flag_sets == 0) + assert(storage._method_exceptions._track == 0) + assert(storage._method_exceptions._treatments_with_config_by_flag_set == 0) + assert(storage._method_exceptions._treatments_with_config_by_flag_sets == 0) + assert(exceptions == {'methodExceptions': {'treatment': 2, 'treatments': 1, 'treatment_with_config': 1, 'treatments_with_config': 5, 'treatments_by_flag_set': 3, 'treatments_by_flag_sets': 10, 'treatments_with_config_by_flag_set': 7, 'treatments_with_config_by_flag_sets': 6, 'track': 3}}) + + await storage.add_tag('tag1') + await storage.add_tag('tag2') + tags = await storage.pop_tags() + assert(storage._tags == []) + assert(tags == ['tag1', 'tag2']) + + [await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, str(i)) for i in [500, 501, 502]] + [await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, str(i)) for i in [400, 401, 402]] + await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, '502') + [await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, str(i)) for i in [501, 502]] + await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, '501') + await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, '505') + [await storage.record_sync_error(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, '502') for i in range(5)] + http_errors = await storage.pop_http_errors() + assert(http_errors == {'httpErrors': {'split': {'400': 1, '401': 1, '402': 1}, 'segment': {'500': 1, '501': 1, '502': 1}, + 'impression': {'502': 1}, 'impressionCount': {'501': 1, '502': 1}, + 'event': {'501': 1}, 'telemetry': {'505': 1}, 'token': {'502': 5}}}) + assert(storage._http_sync_errors._split == {}) + assert(storage._http_sync_errors._segment == {}) + assert(storage._http_sync_errors._impression == {}) + assert(storage._http_sync_errors._impression_count == {}) + assert(storage._http_sync_errors._event == {}) + assert(storage._http_sync_errors._telemetry == {}) + + await storage.record_auth_rejections() + await storage.record_auth_rejections() + auth_rejections = await storage.pop_auth_rejections() + assert(storage._counters._auth_rejections == 0) + assert(auth_rejections == 2) + + await storage.record_token_refreshes() + await storage.record_token_refreshes() + token_refreshes = await storage.pop_token_refreshes() + assert(storage._counters._token_refreshes == 0) + assert(token_refreshes == 2) + + await storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED, 'split', 1234)) + await storage.record_streaming_event((ModelTelemetry.StreamingEventTypes.OCCUPANCY_PRI, 'split', 1234)) + streaming_events = await storage.pop_streaming_events() + assert(storage._streaming_events._streaming_events == []) + assert(streaming_events == {'streamingEvents': [{'e': ModelTelemetry.StreamingEventTypes.CONNECTION_ESTABLISHED.value, 'd': 'split', 't': 1234}, + {'e': ModelTelemetry.StreamingEventTypes.OCCUPANCY_PRI.value, 'd': 'split', 't': 1234}]}) + + @pytest.mark.asyncio + async def test_pop_latencies(self): + storage = await InMemoryTelemetryStorageAsync.create() + + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT, i) for i in [5, 10, 10, 10]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS, i) for i in [7, 10, 14, 13]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENT_WITH_CONFIG, i) for i in [200]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG, i) for i in [50, 40]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SET, i) for i in [15, 20]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_BY_FLAG_SETS, i) for i in [14, 25]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SET, i) for i in [100]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TREATMENTS_WITH_CONFIG_BY_FLAG_SETS, i) for i in [50, 20]] + [await storage.record_latency(ModelTelemetry.MethodExceptionsAndLatencies.TRACK, i) for i in [1, 10, 100]] + latencies = await storage.pop_latencies() + + assert(storage._method_latencies._treatment == [0] * 23) + assert(storage._method_latencies._treatments == [0] * 23) + assert(storage._method_latencies._treatment_with_config == [0] * 23) + assert(storage._method_latencies._treatments_with_config == [0] * 23) + assert(storage._method_latencies._treatments_by_flag_set == [0] * 23) + assert(storage._method_latencies._treatments_by_flag_sets == [0] * 23) + assert(storage._method_latencies._treatments_with_config_by_flag_set == [0] * 23) + assert(storage._method_latencies._treatments_with_config_by_flag_sets == [0] * 23) + assert(storage._method_latencies._track == [0] * 23) + assert(latencies == {'methodLatencies': { + 'treatment': [4] + [0] * 22, + 'treatments': [4] + [0] * 22, + 'treatment_with_config': [1] + [0] * 22, + 'treatments_with_config': [2] + [0] * 22, + 'treatments_by_flag_set': [2] + [0] * 22, + 'treatments_by_flag_sets': [2] + [0] * 22, + 'treatments_with_config_by_flag_set': [1] + [0] * 22, + 'treatments_with_config_by_flag_sets': [2] + [0] * 22, + 'track': [3] + [0] * 22}}) + + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SPLIT, i) for i in [50, 10, 20, 40]] + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.SEGMENT, i) for i in [70, 100, 40, 30]] + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION, i) for i in [10, 20]] + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.IMPRESSION_COUNT, i) for i in [5, 10]] + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.EVENT, i) for i in [50, 40]] + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TELEMETRY, i) for i in [100, 50, 160]] + [await storage.record_sync_latency(ModelTelemetry.HTTPExceptionsAndLatencies.TOKEN, i) for i in [10, 15, 100]] + sync_latency = await storage.pop_http_latencies() + + assert(storage._http_latencies._split == [0] * 23) + assert(storage._http_latencies._segment == [0] * 23) + assert(storage._http_latencies._impression == [0] * 23) + assert(storage._http_latencies._impression_count == [0] * 23) + assert(storage._http_latencies._telemetry == [0] * 23) + assert(storage._http_latencies._token == [0] * 23) + assert(sync_latency == {'httpLatencies': {'split': [4] + [0] * 22, 'segment': [4] + [0] * 22, + 'impression': [2] + [0] * 22, 'impressionCount': [2] + [0] * 22, 'event': [2] + [0] * 22, + 'telemetry': [3] + [0] * 22, 'token': [3] + [0] * 22}}) + +class InMemoryRuleBasedSegmentStorageTests(object): + """In memory rule based segment storage test cases.""" + + def test_storing_retrieving_segments(self, mocker): + """Test storing and retrieving splits works.""" + events_queue = queue.Queue() + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + segment1 = mocker.Mock(spec=RuleBasedSegment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment1).name = name_property + + segment2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'segment2' + type(segment2).name = name2_prop + + rbs_storage.update([segment1, segment2], [], -1) + assert rbs_storage.get('some_segment') == segment1 + assert rbs_storage.get_segment_names() == ['some_segment', 'segment2'] + assert rbs_storage.get('nonexistant_segment') is None + + rbs_storage.update([], ['some_segment'], -1) + assert rbs_storage.get('some_segment') is None + + def test_store_get_changenumber(self): + """Test that storing and retrieving change numbers works.""" + events_queue = queue.Queue() + storage = InMemoryRuleBasedSegmentStorage(events_queue) + assert storage.get_change_number() == -1 + storage.update([], [], 5) + assert storage.get_change_number() == 5 + + def test_contains(self): + raw = { + "changeNumber": 123, + "name": "segment1", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":[], + "segments":[] + }, + "conditions": [] + } + segment1 = rule_based_segments.from_raw(raw) + raw2 = copy.deepcopy(raw) + raw2["name"] = "segment2" + segment2 = rule_based_segments.from_raw(raw2) + raw3 = copy.deepcopy(raw) + raw3["name"] = "segment3" + segment3 = rule_based_segments.from_raw(raw3) + events_queue = queue.Queue() + storage = InMemoryRuleBasedSegmentStorage(events_queue) + storage.update([segment1, segment2, segment3], [], -1) + assert storage.contains(["segment1"]) + assert storage.contains(["segment1", "segment3"]) + assert not storage.contains(["segment5"]) + + def test_internal_event_notification(self, mocker): + """Test storing and retrieving splits works.""" + events_queue = queue.Queue() + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + segment1 = mocker.Mock(spec=RuleBasedSegment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment1).name = name_property + + segment2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'segment2' + type(segment2).name = name2_prop + + rbs_storage.update([segment1, segment2], [], -1) + event = events_queue.get() + assert event.internal_event == SdkInternalEvent.RB_SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + + rbs_storage.update([], ['some_segment'], -1) + assert event.internal_event == SdkInternalEvent.RB_SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + +class InMemoryRuleBasedSegmentStorageAsyncTests(object): + """In memory rule based segment storage test cases.""" + + @pytest.mark.asyncio + async def test_storing_retrieving_segments(self, mocker): + """Test storing and retrieving splits works.""" + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(asyncio.Queue()) + + segment1 = mocker.Mock(spec=RuleBasedSegment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment1).name = name_property + + segment2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'segment2' + type(segment2).name = name2_prop + + await rbs_storage.update([segment1, segment2], [], -1) + assert await rbs_storage.get('some_segment') == segment1 + assert await rbs_storage.get_segment_names() == ['some_segment', 'segment2'] + assert await rbs_storage.get('nonexistant_segment') is None + + await rbs_storage.update([], ['some_segment'], -1) + assert await rbs_storage.get('some_segment') is None + + @pytest.mark.asyncio + async def test_store_get_changenumber(self): + """Test that storing and retrieving change numbers works.""" + storage = InMemoryRuleBasedSegmentStorageAsync(asyncio.Queue()) + assert await storage.get_change_number() == -1 + await storage.update([], [], 5) + assert await storage.get_change_number() == 5 + + @pytest.mark.asyncio + async def test_contains(self): + raw = { + "changeNumber": 123, + "name": "segment1", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":[], + "segments":[] + }, + "conditions": [] + } + segment1 = rule_based_segments.from_raw(raw) + raw2 = copy.deepcopy(raw) + raw2["name"] = "segment2" + segment2 = rule_based_segments.from_raw(raw2) + raw3 = copy.deepcopy(raw) + raw3["name"] = "segment3" + segment3 = rule_based_segments.from_raw(raw3) + storage = InMemoryRuleBasedSegmentStorageAsync(asyncio.Queue()) + await storage.update([segment1, segment2, segment3], [], -1) + assert await storage.contains(["segment1"]) + assert await storage.contains(["segment1", "segment3"]) + assert not await storage.contains(["segment5"]) + + @pytest.mark.asyncio + async def test_internal_event_notification(self, mocker): + """Test storing and retrieving splits works.""" + events_queue = asyncio.Queue() + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(events_queue) + + segment1 = mocker.Mock(spec=RuleBasedSegment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment1).name = name_property + + segment2 = mocker.Mock() + name2_prop = mocker.PropertyMock() + name2_prop.return_value = 'segment2' + type(segment2).name = name2_prop + + await rbs_storage.update([segment1, segment2], [], -1) + event = await events_queue.get() + assert event.internal_event == SdkInternalEvent.RB_SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + + await rbs_storage.update([], ['some_segment'], -1) + event = await events_queue.get() + assert event.internal_event == SdkInternalEvent.RB_SEGMENTS_UPDATED + assert event.metadata.get_type() == SdkEventType.SEGMENTS_UPDATE + assert len(event.metadata.get_names()) == 0 + diff --git a/tests/storage/test_pluggable.py b/tests/storage/test_pluggable.py new file mode 100644 index 00000000..8b5f9a95 --- /dev/null +++ b/tests/storage/test_pluggable.py @@ -0,0 +1,1500 @@ +"""Pluggable storage test module.""" +import json +import threading +import copy +import pytest + +from splitio.optional.loaders import asyncio +from splitio.models.splits import Split +from splitio.models import splits, segments, rule_based_segments +from splitio.models.segments import Segment +from splitio.models.impressions import Impression +from splitio.models.events import Event, EventWrapper +from splitio.storage.pluggable import PluggableSplitStorage, PluggableSegmentStorage, PluggableImpressionsStorage, PluggableEventsStorage, \ + PluggableTelemetryStorage, PluggableEventsStorageAsync, PluggableSegmentStorageAsync, PluggableImpressionsStorageAsync,\ + PluggableSplitStorageAsync, PluggableTelemetryStorageAsync, PluggableRuleBasedSegmentsStorage, PluggableRuleBasedSegmentsStorageAsync +from splitio.client.util import get_metadata, SdkMetadata +from splitio.models.telemetry import MAX_TAGS, MethodExceptionsAndLatencies, OperationMode +from tests.integration import splits_json, rbsegments_json + +class StorageMockAdapter(object): + def __init__(self): + self._keys = {} + self._expire = {} + self._lock = threading.RLock() + + def get(self, key): + with self._lock: + if key not in self._keys: + return None + return self._keys[key] + + def get_items(self, key): + with self._lock: + if key not in self._keys: + return None + return list(self._keys[key]) + + def set(self, key, value): + with self._lock: + self._keys[key] = value + + def push_items(self, key, *value): + with self._lock: + items = [] + if key in self._keys: + items = self._keys[key] + [items.append(item) for item in value] + self._keys[key] = items + return len(self._keys[key]) + + def delete(self, key): + with self._lock: + if key in self._keys: + del self._keys[key] + + def pop_items(self, key): + with self._lock: + if key not in self._keys: + return None + items = list(self._keys[key]) + del self._keys[key] + return items + + def increment(self, key, value): + with self._lock: + if key not in self._keys: + self._keys[key] = 0 + self._keys[key]+= value + return self._keys[key] + + def decrement(self, key, value): + with self._lock: + if key not in self._keys: + return None + self._keys[key]-= value + return self._keys[key] + + def get_keys_by_prefix(self, prefix): + with self._lock: + keys = [] + for key in self._keys: + if prefix in key: + keys.append(key) + return keys + + def get_many(self, keys): + with self._lock: + returned_keys = [] + for key in self._keys: + if key in keys: + returned_keys.append(self._keys[key]) + return returned_keys + + def add_items(self, key, added_items): + with self._lock: + items = set() + if key in self._keys: + items = set(self._keys[key]) + [items.add(item) for item in added_items] + self._keys[key] = items + + def remove_items(self, key, removed_items): + with self._lock: + new_items = set() + for item in self._keys[key]: + if item not in removed_items: + new_items.add(item) + self._keys[key] = new_items + + def item_contains(self, key, item): + with self._lock: + if item in self._keys[key]: + return True + return False + + def get_items_count(self, key): + with self._lock: + if key in self._keys: + return len(self._keys[key]) + return None + + def expire(self, key, ttl): + with self._lock: + if key in self._expire: + self._expire[key] = -1 + else: + self._expire[key] = ttl + # should only be called once per key. + +class StorageMockAdapterAsync(object): + def __init__(self): + self._keys = {} + self._expire = {} + self._lock = asyncio.Lock() + + async def get(self, key): + async with self._lock: + if key not in self._keys: + return None + return self._keys[key] + + async def get_items(self, key): + async with self._lock: + if key not in self._keys: + return None + return list(self._keys[key]) + + async def set(self, key, value): + async with self._lock: + self._keys[key] = value + + async def push_items(self, key, *value): + async with self._lock: + items = [] + if key in self._keys: + items = self._keys[key] + [items.append(item) for item in value] + self._keys[key] = items + return len(self._keys[key]) + + async def delete(self, key): + async with self._lock: + if key in self._keys: + del self._keys[key] + + async def pop_items(self, key): + async with self._lock: + if key not in self._keys: + return None + items = list(self._keys[key]) + del self._keys[key] + return items + + async def increment(self, key, value): + async with self._lock: + if key not in self._keys: + self._keys[key] = 0 + self._keys[key]+= value + return self._keys[key] + + async def decrement(self, key, value): + async with self._lock: + if key not in self._keys: + return None + self._keys[key]-= value + return self._keys[key] + + async def get_keys_by_prefix(self, prefix): + async with self._lock: + keys = [] + for key in self._keys: + if prefix in key: + keys.append(key) + return keys + + async def get_many(self, keys): + async with self._lock: + returned_keys = [] + for key in self._keys: + if key in keys: + returned_keys.append(self._keys[key]) + return returned_keys + + async def add_items(self, key, added_items): + async with self._lock: + items = set() + if key in self._keys: + items = set(self._keys[key]) + [items.add(item) for item in added_items] + self._keys[key] = items + + async def remove_items(self, key, removed_items): + async with self._lock: + new_items = set() + for item in self._keys[key]: + if item not in removed_items: + new_items.add(item) + self._keys[key] = new_items + + async def item_contains(self, key, item): + async with self._lock: + if item in self._keys[key]: + return True + return False + + async def get_items_count(self, key): + async with self._lock: + if key in self._keys: + return len(self._keys[key]) + return None + + async def expire(self, key, ttl): + async with self._lock: + if key in self._expire: + self._expire[key] = -1 + else: + self._expire[key] = ttl + + +class PluggableSplitStorageTests(object): + """In memory split storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapter() + + def test_init(self): + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorage(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + assert(pluggable_split_storage._prefix == prefix + "SPLITIO.split.{feature_flag_name}") + assert(pluggable_split_storage._traffic_type_prefix == prefix + "SPLITIO.trafficType.{traffic_type_name}") + assert(pluggable_split_storage._feature_flag_till_prefix == prefix + "SPLITIO.splits.till") + + # TODO: To be added when producer mode is aupported +# def test_put_many(self): +# split1 = splits.from_raw(splits_json['splitChange1_2']['splits'][0]) +# split2_temp = splits_json['splitChange1_2']['splits'][0].copy() +# split2_temp['name'] = 'another_split' +# split2 = splits.from_raw(split2_temp) +# change_number = splits_json['splitChange1_2']['till'] +# traffic_type = splits_json['splitChange1_2']['splits'][0]['trafficTypeName'] +# +# self.pluggable_split_storage.put_many([split1, split2], change_number) +# assert (self.mock_adapter._keys['myprefix.SPLITIO.split.' + split1.name] == split1.to_json()) +# assert (self.mock_adapter._keys['myprefix.SPLITIO.split.' + split2.name] == split2.to_json()) +# assert (self.mock_adapter._keys['myprefix.SPLITIO.trafficType.' + traffic_type] == 2) +# assert (self.mock_adapter._keys["myprefix.SPLITIO.splits.till"] == change_number) + + def test_get(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorage(self.mock_adapter, prefix=sprefix) + + split1 = splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]) + split_name = splits_json['splitChange1_2']['ff']['d'][0]['name'] + + self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split_name), split1.to_json()) + assert(pluggable_split_storage.get(split_name).to_json() == splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]).to_json()) + assert(pluggable_split_storage.get('not_existing') == None) + + def test_fetch_many(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorage(self.mock_adapter, prefix=sprefix) + split1 = splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]) + split2_temp = splits_json['splitChange1_2']['ff']['d'][0].copy() + split2_temp['name'] = 'another_split' + split2 = splits.from_raw(split2_temp) + + self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split1.name), split1.to_json()) + self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split2.name), split2.to_json()) + fetched = pluggable_split_storage.fetch_many([split1.name, split2.name]) + assert(fetched[split1.name].to_json() == split1.to_json()) + assert(fetched[split2.name].to_json() == split2.to_json()) + + # TODO: To be added when producer mode is aupported +# def test_remove(self): +# self.mock_adapter._keys = {} +# split1 = splits.from_raw(splits_json['splitChange1_2']['splits'][0]) +# change_number = splits_json['splitChange1_2']['till'] +# split_name = splits_json['splitChange1_2']['splits'][0]['name'] +# traffic_type = splits_json['splitChange1_2']['splits'][0]['trafficTypeName'] +# +# self.pluggable_split_storage.put_many([split1], change_number) +# assert(self.pluggable_split_storage.traffic_type_exists(traffic_type) == True) +# self.pluggable_split_storage.remove(split1.name) +# assert(self.pluggable_split_storage.get(split_name) == None) +# assert(self.pluggable_split_storage.traffic_type_exists(traffic_type) == False) + + def test_get_change_number(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorage(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + self.mock_adapter.set(prefix + "SPLITIO.splits.till", 1234) + assert(pluggable_split_storage.get_change_number() == 1234) + + def test_get_split_names(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorage(self.mock_adapter, prefix=sprefix) + split1 = splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]) + split2_temp = splits_json['splitChange1_2']['ff']['d'][0].copy() + split2_temp['name'] = 'another_split' + split2 = splits.from_raw(split2_temp) + self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split1.name), split1.to_json()) + self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split2.name), split2.to_json()) + assert(pluggable_split_storage.get_split_names() == [split1.name, split2.name]) + + # TODO: To be added when producer mode is aupported +# def test_kill_locally(self): +# self.mock_adapter._keys = {} +# split_temp = splits_json['splitChange1_2']['splits'][0] +# split_temp['killed'] = False +# split1 = splits.from_raw(split_temp) +# split_name = splits_json['splitChange1_2']['splits'][0]['name'] +# +# self.pluggable_split_storage.put_many([split1], 123) +# + # should not apply if change number is lower +# self.pluggable_split_storage.kill_locally(split_name, "off", 12) +# assert(self.pluggable_split_storage.get(split_name).killed == False) +# +# self.pluggable_split_storage.kill_locally(split_name, "off", 124) +# assert(self.pluggable_split_storage.get(split_name).killed == True) + + # TODO: To be added when producer mode is aupported +# def test_traffic_type_count(self): +# self.mock_adapter._keys = {} +# self.pluggable_split_storage._increase_traffic_type_count('user') +# assert(self.pluggable_split_storage.is_valid_traffic_type('user')) +# +# self.pluggable_split_storage._increase_traffic_type_count('user') +# assert(self.mock_adapter._keys['myprefix.SPLITIO.trafficType.user'] == 2) +# +# self.pluggable_split_storage._decrease_traffic_type_count('user') +# assert(self.mock_adapter._keys['myprefix.SPLITIO.trafficType.user'] == 1) +# +# self.pluggable_split_storage._decrease_traffic_type_count('user') +# assert(not self.pluggable_split_storage.is_valid_traffic_type('user')) + + # TODO: To be added when producer mode is aupported +# def test_put(self): +# self.mock_adapter._keys = {} +# split = splits.from_raw(splits_json['splitChange1_2']['splits'][0]) +# self.pluggable_split_storage.put(split) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.trafficType.user'] == 1) +# assert(split.to_json() == self.mock_adapter.get('myprefix.SPLITIO.split.' + split.name)) +# + # changing traffic type should delete existing one and add new one +# split._traffic_type_name = 'account' +# self.pluggable_split_storage.put(split) +# assert('myprefix.SPLITIO.trafficType.user' not in self.mock_adapter._keys) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.trafficType.account'] == 1) +# + # making update without changing traffic type should not increase the count +# split._killed = 'False' +# self.pluggable_split_storage.put(split) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.trafficType.account'] == 1) +# assert(split.to_json()['killed'] == self.mock_adapter.get('myprefix.SPLITIO.split.' + split.name)['killed']) + + +class PluggableSplitStorageAsyncTests(object): + """In memory async split storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapterAsync() + + def test_init(self): + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorageAsync(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + assert(pluggable_split_storage._prefix == prefix + "SPLITIO.split.{feature_flag_name}") + assert(pluggable_split_storage._traffic_type_prefix == prefix + "SPLITIO.trafficType.{traffic_type_name}") + assert(pluggable_split_storage._feature_flag_till_prefix == prefix + "SPLITIO.splits.till") + + @pytest.mark.asyncio + async def test_get(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorageAsync(self.mock_adapter, prefix=sprefix) + + split1 = splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]) + split_name = splits_json['splitChange1_2']['ff']['d'][0]['name'] + + await self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split_name), split1.to_json()) + split = await pluggable_split_storage.get(split_name) + assert(split.to_json() == splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]).to_json()) + assert(await pluggable_split_storage.get('not_existing') == None) + + @pytest.mark.asyncio + async def test_fetch_many(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorageAsync(self.mock_adapter, prefix=sprefix) + split1 = splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]) + split2_temp = splits_json['splitChange1_2']['ff']['d'][0].copy() + split2_temp['name'] = 'another_split' + split2 = splits.from_raw(split2_temp) + + await self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split1.name), split1.to_json()) + await self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split2.name), split2.to_json()) + fetched = await pluggable_split_storage.fetch_many([split1.name, split2.name]) + assert(fetched[split1.name].to_json() == split1.to_json()) + assert(fetched[split2.name].to_json() == split2.to_json()) + + @pytest.mark.asyncio + async def test_get_change_number(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorageAsync(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + await self.mock_adapter.set(prefix + "SPLITIO.splits.till", 1234) + assert(await pluggable_split_storage.get_change_number() == 1234) + + @pytest.mark.asyncio + async def test_get_split_names(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_split_storage = PluggableSplitStorageAsync(self.mock_adapter, prefix=sprefix) + split1 = splits.from_raw(splits_json['splitChange1_2']['ff']['d'][0]) + split2_temp = splits_json['splitChange1_2']['ff']['d'][0].copy() + split2_temp['name'] = 'another_split' + split2 = splits.from_raw(split2_temp) + await self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split1.name), split1.to_json()) + await self.mock_adapter.set(pluggable_split_storage._prefix.format(feature_flag_name=split2.name), split2.to_json()) + + assert(await pluggable_split_storage.get_split_names() == [split1.name, split2.name]) + +class PluggableSegmentStorageTests(object): + """In memory split storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapter() + + def test_init(self): + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorage(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + assert(pluggable_segment_storage._prefix == prefix + "SPLITIO.segment.{segment_name}") + assert(pluggable_segment_storage._segment_till_prefix == prefix + "SPLITIO.segment.{segment_name}.till") + + # TODO: to be added when get_keys() is added +# def test_update(self): +# self.mock_adapter.set(self.pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) +# self.mock_adapter.set(self.pluggable_segment_storage._segment_till_prefix.format(segment_name='segment1'), 123) +# +# assert('myprefix.SPLITIO.segment.segment1' in self.mock_adapter._keys) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.segment.segment1'] == set(['key1', 'key2'])) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.segment.segment1.till'] == 123) + + def test_get_change_number(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorage(self.mock_adapter, prefix=sprefix) + assert(pluggable_segment_storage.get_change_number('segment1') is None) + + self.mock_adapter.set(pluggable_segment_storage._segment_till_prefix.format(segment_name='segment1'), 123) + assert(pluggable_segment_storage.get_change_number('segment1') == 123) + + # TODO: To be added when producer mode is implemented +# self.pluggable_segment_storage.set_change_number('segment1', 124) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.segment.segment1.till'] == 124) + + def test_get_segment_names(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorage(self.mock_adapter, prefix=sprefix) + assert(pluggable_segment_storage.get_segment_names() == []) + + self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) + self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment2'), {}) + self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment3'), {'key1', 'key5'}) + assert(pluggable_segment_storage.get_segment_names() == ['segment1', 'segment2', 'segment3']) + + # TODO: to be added when get_keys() is added +# def test_get_keys(self): +# self.mock_adapter._keys = {} +# self.pluggable_segment_storage.update('segment1', ['key1', 'key2'], [], 123) +# assert(self.pluggable_segment_storage.get_keys('segment1').sort() == ['key1', 'key2'].sort()) + + def test_segment_contains(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorage(self.mock_adapter, prefix=sprefix) + self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) + assert(not pluggable_segment_storage.segment_contains('segment1', 'key5')) + assert(pluggable_segment_storage.segment_contains('segment1', 'key1')) + + # TODO: To be added when producer mode is implemented +# def get_segment_keys_count(self): +# self.mock_adapter._keys = {} +# self.pluggable_segment_storage.update('segment1', ['key1', 'key2'], [], 123) +# self.pluggable_segment_storage.update('segment2', [], [], 123) +# self.pluggable_segment_storage.update('segment3', ['key1', 'key5'], [], 123) +# assert(self.pluggable_segment_storage.get_segment_keys_count() == 4) + + def test_get(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorage(self.mock_adapter, prefix=sprefix) + self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) + segment = pluggable_segment_storage.get('segment1') + assert(segment.name == 'segment1') + assert(segment.keys == {'key1', 'key2'}) + + # TODO: To be added when producer mode is implemented +# def test_put(self): +# self.mock_adapter._keys = {} +# self.pluggable_segment_storage.update('segment1', ['key1', 'key2'], [], 123) +# segment = self.pluggable_segment_storage.get('segment1') +# segment._name = 'segment2' +# segment._keys.add('key3') +# +# self.pluggable_segment_storage.put(segment) +# assert('myprefix.SPLITIO.segment.segment2' in self.mock_adapter._keys) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.segment.segment2'] == {'key1', 'key2', 'key3'}) +# assert(self.mock_adapter._keys['myprefix.SPLITIO.segment.segment2.till'] == 123) + + +class PluggableSegmentStorageAsyncTests(object): + """In memory async segment storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapterAsync() + + def test_init(self): + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorageAsync(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + assert(pluggable_segment_storage._prefix == prefix + "SPLITIO.segment.{segment_name}") + assert(pluggable_segment_storage._segment_till_prefix == prefix + "SPLITIO.segment.{segment_name}.till") + + @pytest.mark.asyncio + async def test_get_change_number(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorageAsync(self.mock_adapter, prefix=sprefix) + assert(await pluggable_segment_storage.get_change_number('segment1') is None) + + await self.mock_adapter.set(pluggable_segment_storage._segment_till_prefix.format(segment_name='segment1'), 123) + assert(await pluggable_segment_storage.get_change_number('segment1') == 123) + + @pytest.mark.asyncio + async def test_get_segment_names(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorageAsync(self.mock_adapter, prefix=sprefix) + assert(await pluggable_segment_storage.get_segment_names() == []) + + await self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) + await self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment2'), {}) + await self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment3'), {'key1', 'key5'}) + assert(await pluggable_segment_storage.get_segment_names() == ['segment1', 'segment2', 'segment3']) + + @pytest.mark.asyncio + async def test_segment_contains(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorageAsync(self.mock_adapter, prefix=sprefix) + await self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) + assert(not await pluggable_segment_storage.segment_contains('segment1', 'key5')) + assert(await pluggable_segment_storage.segment_contains('segment1', 'key1')) + + @pytest.mark.asyncio + async def test_get(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_segment_storage = PluggableSegmentStorageAsync(self.mock_adapter, prefix=sprefix) + await self.mock_adapter.set(pluggable_segment_storage._prefix.format(segment_name='segment1'), {'key1', 'key2'}) + segment = await pluggable_segment_storage.get('segment1') + assert(segment.name == 'segment1') + assert(segment.keys == {'key1', 'key2'}) + + +class PluggableImpressionsStorageTests(object): + """In memory impressions storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapter() + self.metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + + def test_init(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_imp_storage = PluggableImpressionsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + assert(pluggable_imp_storage._impressions_queue_key == prefix + "SPLITIO.impressions") + assert(pluggable_imp_storage._sdk_metadata == { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }) + + + def test_put(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_imp_storage = PluggableImpressionsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) + ] + assert(pluggable_imp_storage.put(impressions)) + assert(pluggable_imp_storage._impressions_queue_key in self.mock_adapter._keys) + assert(self.mock_adapter._keys[prefix + "SPLITIO.impressions"] == pluggable_imp_storage._wrap_impressions(impressions)) + assert(self.mock_adapter._expire[prefix + "SPLITIO.impressions"] == PluggableImpressionsStorage.IMPRESSIONS_KEY_DEFAULT_TTL) + + impressions2 = [ + Impression('key5', 'feature1', 'off', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key6', 'feature2', 'off', 'some_label', 123456, 'buck1', 321654, None, None), + ] + assert(pluggable_imp_storage.put(impressions2)) + assert(self.mock_adapter._keys[prefix + "SPLITIO.impressions"] == pluggable_imp_storage._wrap_impressions(impressions + impressions2)) + + def test_wrap_impressions(self): + for sprefix in [None, 'myprefix']: + pluggable_imp_storage = PluggableImpressionsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'off', 'some_label', 123456, 'buck1', 321654, None, None), + ] + assert(pluggable_imp_storage._wrap_impressions(impressions) == [ + json.dumps({ + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }, + 'i': { + 'k': 'key1', + 'b': 'buck1', + 'f': 'feature1', + 't': 'on', + 'r': 'some_label', + 'c': 123456, + 'm': 321654, + 'properties': None + } + }), + json.dumps({ + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }, + 'i': { + 'k': 'key2', + 'b': 'buck1', + 'f': 'feature2', + 't': 'off', + 'r': 'some_label', + 'c': 123456, + 'm': 321654, + 'properties': None + } + }) + ]) + + def test_expire_key(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_imp_storage = PluggableImpressionsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + self.expired_called = False + self.key = "" + self.ttl = 0 + def mock_expire(impressions_queue_key, ttl): + self.key = impressions_queue_key + self.ttl = ttl + self.expired_called = True + + self.mock_adapter.expire = mock_expire + + # should not call if total_keys are higher + pluggable_imp_storage.expire_key(200, 10) + assert(not self.expired_called) + + pluggable_imp_storage.expire_key(200, 200) + assert(self.expired_called) + assert(self.key == prefix + "SPLITIO.impressions") + assert(self.ttl == pluggable_imp_storage.IMPRESSIONS_KEY_DEFAULT_TTL) + + +class PluggableImpressionsStorageAsyncTests(object): + """In memory impressions storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapterAsync() + self.metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + + def test_init(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_imp_storage = PluggableImpressionsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + assert(pluggable_imp_storage._impressions_queue_key == prefix + "SPLITIO.impressions") + assert(pluggable_imp_storage._sdk_metadata == { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }) + + @pytest.mark.asyncio + async def test_put(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_imp_storage = PluggableImpressionsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) + ] + assert(await pluggable_imp_storage.put(impressions)) + assert(pluggable_imp_storage._impressions_queue_key in self.mock_adapter._keys) + assert(self.mock_adapter._keys[prefix + "SPLITIO.impressions"] == pluggable_imp_storage._wrap_impressions(impressions)) + assert(self.mock_adapter._expire[prefix + "SPLITIO.impressions"] == PluggableImpressionsStorageAsync.IMPRESSIONS_KEY_DEFAULT_TTL) + + impressions2 = [ + Impression('key5', 'feature1', 'off', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key6', 'feature2', 'off', 'some_label', 123456, 'buck1', 321654, None, None), + ] + assert(await pluggable_imp_storage.put(impressions2)) + assert(self.mock_adapter._keys[prefix + "SPLITIO.impressions"] == pluggable_imp_storage._wrap_impressions(impressions + impressions2)) + + def test_wrap_impressions(self): + for sprefix in [None, 'myprefix']: + pluggable_imp_storage = PluggableImpressionsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'off', 'some_label', 123456, 'buck1', 321654, None, None), + ] + assert(pluggable_imp_storage._wrap_impressions(impressions) == [ + json.dumps({ + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }, + 'i': { + 'k': 'key1', + 'b': 'buck1', + 'f': 'feature1', + 't': 'on', + 'r': 'some_label', + 'c': 123456, + 'm': 321654, + 'properties': None + } + }), + json.dumps({ + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }, + 'i': { + 'k': 'key2', + 'b': 'buck1', + 'f': 'feature2', + 't': 'off', + 'r': 'some_label', + 'c': 123456, + 'm': 321654, + 'properties': None + } + }) + ]) + + @pytest.mark.asyncio + async def test_expire_key(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_imp_storage = PluggableImpressionsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + self.expired_called = False + self.key = "" + self.ttl = 0 + async def mock_expire(impressions_queue_key, ttl): + self.key = impressions_queue_key + self.ttl = ttl + self.expired_called = True + + self.mock_adapter.expire = mock_expire + + # should not call if total_keys are higher + await pluggable_imp_storage.expire_key(200, 10) + assert(not self.expired_called) + + await pluggable_imp_storage.expire_key(200, 200) + assert(self.expired_called) + assert(self.key == prefix + "SPLITIO.impressions") + assert(self.ttl == pluggable_imp_storage.IMPRESSIONS_KEY_DEFAULT_TTL) + + +class PluggableEventsStorageTests(object): + """Pluggable events storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapter() + self.metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + + def test_init(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_events_storage = PluggableEventsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + assert(pluggable_events_storage._events_queue_key == prefix + "SPLITIO.events") + assert(pluggable_events_storage._sdk_metadata == { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }) + + def test_put(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_events_storage = PluggableEventsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + events = [ + EventWrapper(event=Event('key1', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key2', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key3', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key4', 'user', 'purchase', 10, 123456, None), size=32768), + ] + assert(pluggable_events_storage.put(events)) + assert(pluggable_events_storage._events_queue_key in self.mock_adapter._keys) + assert(self.mock_adapter._keys[prefix + "SPLITIO.events"] == pluggable_events_storage._wrap_events(events)) + assert(self.mock_adapter._expire[prefix + "SPLITIO.events"] == PluggableEventsStorage._EVENTS_KEY_DEFAULT_TTL) + + events2 = [ + EventWrapper(event=Event('key5', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key6', 'user', 'purchase', 10, 123456, None), size=32768), + ] + assert(pluggable_events_storage.put(events2)) + assert(self.mock_adapter._keys[prefix + "SPLITIO.events"] == pluggable_events_storage._wrap_events(events + events2)) + + def test_wrap_events(self): + for sprefix in [None, 'myprefix']: + pluggable_events_storage = PluggableEventsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + events = [ + EventWrapper(event=Event('key1', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key2', 'user', 'purchase', 10, 123456, None), size=32768), + ] + assert(pluggable_events_storage._wrap_events(events) == [ + json.dumps({ + 'e': { + 'key': 'key1', + 'trafficTypeName': 'user', + 'eventTypeId': 'purchase', + 'value': 10, + 'timestamp': 123456, + 'properties': None, + }, + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + } + }), + json.dumps({ + 'e': { + 'key': 'key2', + 'trafficTypeName': 'user', + 'eventTypeId': 'purchase', + 'value': 10, + 'timestamp': 123456, + 'properties': None, + }, + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + } + }) + ]) + + def test_expire_key(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_events_storage = PluggableEventsStorage(self.mock_adapter, self.metadata, prefix=sprefix) + self.expired_called = False + self.key = "" + self.ttl = 0 + def mock_expire(impressions_event_key, ttl): + self.key = impressions_event_key + self.ttl = ttl + self.expired_called = True + + self.mock_adapter.expire = mock_expire + + # should not call if total_keys are higher + pluggable_events_storage.expire_key(200, 10) + assert(not self.expired_called) + + pluggable_events_storage.expire_key(200, 200) + assert(self.expired_called) + assert(self.key == prefix + "SPLITIO.events") + assert(self.ttl == pluggable_events_storage._EVENTS_KEY_DEFAULT_TTL) + + +class PluggableEventsStorageAsyncTests(object): + """Pluggable events storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapterAsync() + self.metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + + def test_init(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_events_storage = PluggableEventsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + assert(pluggable_events_storage._events_queue_key == prefix + "SPLITIO.events") + assert(pluggable_events_storage._sdk_metadata == { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + }) + + @pytest.mark.asyncio + async def test_put(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_events_storage = PluggableEventsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + events = [ + EventWrapper(event=Event('key1', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key2', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key3', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key4', 'user', 'purchase', 10, 123456, None), size=32768), + ] + assert(await pluggable_events_storage.put(events)) + assert(pluggable_events_storage._events_queue_key in self.mock_adapter._keys) + assert(self.mock_adapter._keys[prefix + "SPLITIO.events"] == pluggable_events_storage._wrap_events(events)) + assert(self.mock_adapter._expire[prefix + "SPLITIO.events"] == PluggableEventsStorageAsync._EVENTS_KEY_DEFAULT_TTL) + + events2 = [ + EventWrapper(event=Event('key5', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key6', 'user', 'purchase', 10, 123456, None), size=32768), + ] + assert(await pluggable_events_storage.put(events2)) + assert(self.mock_adapter._keys[prefix + "SPLITIO.events"] == pluggable_events_storage._wrap_events(events + events2)) + + def test_wrap_events(self): + for sprefix in [None, 'myprefix']: + pluggable_events_storage = PluggableEventsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + events = [ + EventWrapper(event=Event('key1', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key2', 'user', 'purchase', 10, 123456, None), size=32768), + ] + assert(pluggable_events_storage._wrap_events(events) == [ + json.dumps({ + 'e': { + 'key': 'key1', + 'trafficTypeName': 'user', + 'eventTypeId': 'purchase', + 'value': 10, + 'timestamp': 123456, + 'properties': None, + }, + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + } + }), + json.dumps({ + 'e': { + 'key': 'key2', + 'trafficTypeName': 'user', + 'eventTypeId': 'purchase', + 'value': 10, + 'timestamp': 123456, + 'properties': None, + }, + 'm': { + 's': self.metadata.sdk_version, + 'n': self.metadata.instance_name, + 'i': self.metadata.instance_ip, + } + }) + ]) + + @pytest.mark.asyncio + async def test_expire_key(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_events_storage = PluggableEventsStorageAsync(self.mock_adapter, self.metadata, prefix=sprefix) + self.expired_called = False + self.key = "" + self.ttl = 0 + async def mock_expire(impressions_event_key, ttl): + self.key = impressions_event_key + self.ttl = ttl + self.expired_called = True + + self.mock_adapter.expire = mock_expire + + # should not call if total_keys are higher + await pluggable_events_storage.expire_key(200, 10) + assert(not self.expired_called) + + await pluggable_events_storage.expire_key(200, 200) + assert(self.expired_called) + assert(self.key == prefix + "SPLITIO.events") + assert(self.ttl == pluggable_events_storage._EVENTS_KEY_DEFAULT_TTL) + + +class PluggableTelemetryStorageTests(object): + """Pluggable telemetry storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapter() + self.sdk_metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + + def test_init(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + assert(pluggable_telemetry_storage._telemetry_config_key == prefix + 'SPLITIO.telemetry.init') + assert(pluggable_telemetry_storage._telemetry_latencies_key == prefix + 'SPLITIO.telemetry.latencies') + assert(pluggable_telemetry_storage._telemetry_exceptions_key == prefix + 'SPLITIO.telemetry.exceptions') + assert(pluggable_telemetry_storage._sdk_metadata == self.sdk_metadata.sdk_version + '/' + self.sdk_metadata.instance_name + '/' + self.sdk_metadata.instance_ip) + assert(pluggable_telemetry_storage._config_tags == []) + + def test_reset_config_tags(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + pluggable_telemetry_storage._config_tags = ['a'] + pluggable_telemetry_storage._reset_config_tags() + assert(pluggable_telemetry_storage._config_tags == []) + + def test_add_config_tag(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + pluggable_telemetry_storage.add_config_tag('q') + assert(pluggable_telemetry_storage._config_tags == ['q']) + + pluggable_telemetry_storage._config_tags = [] + for i in range(0, 20): + pluggable_telemetry_storage.add_config_tag('q' + str(i)) + assert(len(pluggable_telemetry_storage._config_tags) == MAX_TAGS) + assert(pluggable_telemetry_storage._config_tags == ['q' + str(i) for i in range(0, MAX_TAGS)]) + + def test_record_config(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + self.config = {} + self.extra_config = {} + def record_config_mock(config, extra_config, af, inf): + self.config = config + self.extra_config = extra_config + + pluggable_telemetry_storage._tel_config.record_config = record_config_mock + pluggable_telemetry_storage.record_config({'item': 'value'}, {'item2': 'value2'}, 0, 0) + assert(self.config == {'item': 'value'}) + assert(self.extra_config == {'item2': 'value2'}) + + def test_pop_config_tags(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + pluggable_telemetry_storage._config_tags = ['a'] + pluggable_telemetry_storage.pop_config_tags() + assert(pluggable_telemetry_storage._config_tags == []) + + def test_record_active_and_redundant_factories(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + self.active_factory_count = 0 + self.redundant_factory_count = 0 + def record_active_and_redundant_factories_mock(active_factory_count, redundant_factory_count): + self.active_factory_count = active_factory_count + self.redundant_factory_count = redundant_factory_count + + pluggable_telemetry_storage._tel_config.record_active_and_redundant_factories = record_active_and_redundant_factories_mock + pluggable_telemetry_storage.record_active_and_redundant_factories(2, 1) + assert(self.active_factory_count == 2) + assert(self.redundant_factory_count == 1) + + def test_record_latency(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + def expire_keys_mock(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/0') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 1) + pluggable_telemetry_storage.expire_keys = expire_keys_mock + # should increment bucket 0 + pluggable_telemetry_storage.record_latency(MethodExceptionsAndLatencies.TREATMENT, 0) + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/0'] == 1) + + def expire_keys_mock2(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/3') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 1) + pluggable_telemetry_storage.expire_keys = expire_keys_mock2 + # should increment bucket 3 + pluggable_telemetry_storage.record_latency(MethodExceptionsAndLatencies.TREATMENT, 3) + + def expire_keys_mock3(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/3') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 2) + pluggable_telemetry_storage.expire_keys = expire_keys_mock3 + # should increment bucket 3 + pluggable_telemetry_storage.record_latency(MethodExceptionsAndLatencies.TREATMENT, 3) + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/3'] == 2) + + def test_record_exception(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + def expire_keys_mock(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_exceptions_key + '::python-1.1.1/hostname/ip/treatment') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 1) + + pluggable_telemetry_storage.expire_keys = expire_keys_mock + pluggable_telemetry_storage.record_exception(MethodExceptionsAndLatencies.TREATMENT) + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_exceptions_key + '::python-1.1.1/hostname/ip/treatment'] == 1) + + def test_push_config_stats(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = PluggableTelemetryStorage(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + pluggable_telemetry_storage.record_config( + {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG','' + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None + }, {}, 0, 0 + ) + pluggable_telemetry_storage.record_active_and_redundant_factories(2, 1) + pluggable_telemetry_storage.push_config_stats() + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_config_key + "::" + pluggable_telemetry_storage._sdk_metadata] == '{"aF": 2, "rF": 1, "sT": "memory", "oM": 0, "t": []}') + + +class PluggableTelemetryStorageAsyncTests(object): + """Pluggable telemetry storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapterAsync() + self.sdk_metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + + @pytest.mark.asyncio + async def test_init(self): + for sprefix in [None, 'myprefix']: + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + assert(pluggable_telemetry_storage._telemetry_config_key == prefix + 'SPLITIO.telemetry.init') + assert(pluggable_telemetry_storage._telemetry_latencies_key == prefix + 'SPLITIO.telemetry.latencies') + assert(pluggable_telemetry_storage._telemetry_exceptions_key == prefix + 'SPLITIO.telemetry.exceptions') + assert(pluggable_telemetry_storage._sdk_metadata == self.sdk_metadata.sdk_version + '/' + self.sdk_metadata.instance_name + '/' + self.sdk_metadata.instance_ip) + assert(pluggable_telemetry_storage._config_tags == []) + + @pytest.mark.asyncio + async def test_reset_config_tags(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + pluggable_telemetry_storage._config_tags = ['a'] + await pluggable_telemetry_storage._reset_config_tags() + assert(pluggable_telemetry_storage._config_tags == []) + + @pytest.mark.asyncio + async def test_add_config_tag(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + await pluggable_telemetry_storage.add_config_tag('q') + assert(pluggable_telemetry_storage._config_tags == ['q']) + + pluggable_telemetry_storage._config_tags = [] + for i in range(0, 20): + await pluggable_telemetry_storage.add_config_tag('q' + str(i)) + assert(len(pluggable_telemetry_storage._config_tags) == MAX_TAGS) + assert(pluggable_telemetry_storage._config_tags == ['q' + str(i) for i in range(0, MAX_TAGS)]) + + @pytest.mark.asyncio + async def test_record_config(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + self.config = {} + self.extra_config = {} + async def record_config_mock(config, extra_config, tf, ifs): + self.config = config + self.extra_config = extra_config + + pluggable_telemetry_storage._tel_config.record_config = record_config_mock + await pluggable_telemetry_storage.record_config({'item': 'value'}, {'item2': 'value2'}, 0, 0) + assert(self.config == {'item': 'value'}) + assert(self.extra_config == {'item2': 'value2'}) + + @pytest.mark.asyncio + async def test_pop_config_tags(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + pluggable_telemetry_storage._config_tags = ['a'] + await pluggable_telemetry_storage.pop_config_tags() + assert(pluggable_telemetry_storage._config_tags == []) + + @pytest.mark.asyncio + async def test_record_active_and_redundant_factories(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + self.active_factory_count = 0 + self.redundant_factory_count = 0 + async def record_active_and_redundant_factories_mock(active_factory_count, redundant_factory_count): + self.active_factory_count = active_factory_count + self.redundant_factory_count = redundant_factory_count + + pluggable_telemetry_storage._tel_config.record_active_and_redundant_factories = record_active_and_redundant_factories_mock + await pluggable_telemetry_storage.record_active_and_redundant_factories(2, 1) + assert(self.active_factory_count == 2) + assert(self.redundant_factory_count == 1) + + @pytest.mark.asyncio + async def test_record_latency(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + async def expire_keys_mock(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/0') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 1) + pluggable_telemetry_storage.expire_keys = expire_keys_mock + # should increment bucket 0 + await pluggable_telemetry_storage.record_latency(MethodExceptionsAndLatencies.TREATMENT, 0) + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/0'] == 1) + + async def expire_keys_mock2(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/3') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 1) + pluggable_telemetry_storage.expire_keys = expire_keys_mock2 + # should increment bucket 3 + await pluggable_telemetry_storage.record_latency(MethodExceptionsAndLatencies.TREATMENT, 3) + + async def expire_keys_mock3(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/3') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 2) + pluggable_telemetry_storage.expire_keys = expire_keys_mock3 + # should increment bucket 3 + await pluggable_telemetry_storage.record_latency(MethodExceptionsAndLatencies.TREATMENT, 3) + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_latencies_key + '::python-1.1.1/hostname/ip/treatment/3'] == 2) + + @pytest.mark.asyncio + async def test_record_exception(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + async def expire_keys_mock(*args, **kwargs): + assert(args[0] == pluggable_telemetry_storage._telemetry_exceptions_key + '::python-1.1.1/hostname/ip/treatment') + assert(args[1] == pluggable_telemetry_storage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[2] == 1) + assert(args[3] == 1) + + pluggable_telemetry_storage.expire_keys = expire_keys_mock + await pluggable_telemetry_storage.record_exception(MethodExceptionsAndLatencies.TREATMENT) + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_exceptions_key + '::python-1.1.1/hostname/ip/treatment'] == 1) + + @pytest.mark.asyncio + async def test_push_config_stats(self): + for sprefix in [None, 'myprefix']: + pluggable_telemetry_storage = await PluggableTelemetryStorageAsync.create(self.mock_adapter, self.sdk_metadata, prefix=sprefix) + await pluggable_telemetry_storage.record_config( + {'operationMode': 'standalone', + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG','' + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'storageType': None + }, {}, 0, 0 + ) + await pluggable_telemetry_storage.record_active_and_redundant_factories(2, 1) + await pluggable_telemetry_storage.push_config_stats() + assert(self.mock_adapter._keys[pluggable_telemetry_storage._telemetry_config_key + "::" + pluggable_telemetry_storage._sdk_metadata] == '{"aF": 2, "rF": 1, "sT": "memory", "oM": 0, "t": []}') + +class PluggableRuleBasedSegmentStorageTests(object): + """In memory rule based segment storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapter() + + def test_get(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorage(self.mock_adapter, prefix=sprefix) + + rbs1 = rule_based_segments.from_raw(rbsegments_json[0]) + rbs_name = rbsegments_json[0]['name'] + + self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs_name), rbs1.to_json()) + assert(pluggable_rbs_storage.get(rbs_name).to_json() == rule_based_segments.from_raw(rbsegments_json[0]).to_json()) + assert(pluggable_rbs_storage.get('not_existing') == None) + + def test_get_change_number(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorage(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + self.mock_adapter.set(prefix + "SPLITIO.rbsegments.till", 1234) + assert(pluggable_rbs_storage.get_change_number() == 1234) + + def test_get_segment_names(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorage(self.mock_adapter, prefix=sprefix) + rbs1 = rule_based_segments.from_raw(rbsegments_json[0]) + rbs2_temp = copy.deepcopy(rbsegments_json[0]) + rbs2_temp['name'] = 'another_segment' + rbs2 = rule_based_segments.from_raw(rbs2_temp) + self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs1.name), rbs1.to_json()) + self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs2.name), rbs2.to_json()) + assert(pluggable_rbs_storage.get_segment_names() == [rbs1.name, rbs2.name]) + + def test_contains(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorage(self.mock_adapter, prefix=sprefix) + rbs1 = rule_based_segments.from_raw(rbsegments_json[0]) + rbs2_temp = copy.deepcopy(rbsegments_json[0]) + rbs2_temp['name'] = 'another_segment' + rbs2 = rule_based_segments.from_raw(rbs2_temp) + self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs1.name), rbs1.to_json()) + self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs2.name), rbs2.to_json()) + + assert(pluggable_rbs_storage.contains([rbs1.name, rbs2.name])) + assert(pluggable_rbs_storage.contains([rbs2.name])) + assert(not pluggable_rbs_storage.contains(['none-exists', rbs2.name])) + assert(not pluggable_rbs_storage.contains(['none-exists', 'none-exists2'])) + +class PluggableRuleBasedSegmentStorageAsyncTests(object): + """In memory rule based segment storage test cases.""" + + def setup_method(self): + """Prepare storages with test data.""" + self.mock_adapter = StorageMockAdapterAsync() + + @pytest.mark.asyncio + async def test_get(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorageAsync(self.mock_adapter, prefix=sprefix) + + rbs1 = rule_based_segments.from_raw(rbsegments_json[0]) + rbs_name = rbsegments_json[0]['name'] + + await self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs_name), rbs1.to_json()) + rbs = await pluggable_rbs_storage.get(rbs_name) + assert(rbs.to_json() == rule_based_segments.from_raw(rbsegments_json[0]).to_json()) + assert(await pluggable_rbs_storage.get('not_existing') == None) + + @pytest.mark.asyncio + async def test_get_change_number(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorageAsync(self.mock_adapter, prefix=sprefix) + if sprefix == 'myprefix': + prefix = 'myprefix.' + else: + prefix = '' + await self.mock_adapter.set(prefix + "SPLITIO.rbsegments.till", 1234) + assert(await pluggable_rbs_storage.get_change_number() == 1234) + + @pytest.mark.asyncio + async def test_get_segment_names(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorageAsync(self.mock_adapter, prefix=sprefix) + rbs1 = rule_based_segments.from_raw(rbsegments_json[0]) + rbs2_temp = copy.deepcopy(rbsegments_json[0]) + rbs2_temp['name'] = 'another_segment' + rbs2 = rule_based_segments.from_raw(rbs2_temp) + await self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs1.name), rbs1.to_json()) + await self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs2.name), rbs2.to_json()) + assert(await pluggable_rbs_storage.get_segment_names() == [rbs1.name, rbs2.name]) + + @pytest.mark.asyncio + async def test_contains(self): + self.mock_adapter._keys = {} + for sprefix in [None, 'myprefix']: + pluggable_rbs_storage = PluggableRuleBasedSegmentsStorageAsync(self.mock_adapter, prefix=sprefix) + rbs1 = rule_based_segments.from_raw(rbsegments_json[0]) + rbs2_temp = copy.deepcopy(rbsegments_json[0]) + rbs2_temp['name'] = 'another_segment' + rbs2 = rule_based_segments.from_raw(rbs2_temp) + await self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs1.name), rbs1.to_json()) + await self.mock_adapter.set(pluggable_rbs_storage._prefix.format(segment_name=rbs2.name), rbs2.to_json()) + + assert(await pluggable_rbs_storage.contains([rbs1.name, rbs2.name])) + assert(await pluggable_rbs_storage.contains([rbs2.name])) + assert(not await pluggable_rbs_storage.contains(['none-exists', rbs2.name])) + assert(not await pluggable_rbs_storage.contains(['none-exists', 'none-exists2'])) diff --git a/tests/storage/test_redis.py b/tests/storage/test_redis.py index 2a239904..a45c4ad2 100644 --- a/tests/storage/test_redis.py +++ b/tests/storage/test_redis.py @@ -3,15 +3,24 @@ import json import time - -from splitio.client.util import get_metadata -from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ - RedisSegmentStorage, RedisSplitStorage +import unittest.mock as mock +import redis.asyncio as aioredis +import pytest + +from splitio.client.util import get_metadata, SdkMetadata +from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterAsync, RedisAdapterException, build +from splitio.optional.loaders import asyncio +from splitio.storage import FlagSetsFilter +from splitio.storage.redis import RedisEventsStorage, RedisEventsStorageAsync, RedisImpressionsStorage, RedisImpressionsStorageAsync, \ + RedisSegmentStorage, RedisSegmentStorageAsync, RedisSplitStorage, RedisSplitStorageAsync, RedisTelemetryStorage, RedisTelemetryStorageAsync, \ + RedisRuleBasedSegmentsStorage, RedisRuleBasedSegmentsStorageAsync +from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterException, build +from redis.asyncio.client import Redis as aioredis +from splitio.storage.adapters import redis from splitio.models.segments import Segment from splitio.models.impressions import Impression from splitio.models.events import Event, EventWrapper -from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterException - +from splitio.models.telemetry import MethodExceptions, MethodLatencies, TelemetryConfig, MethodExceptionsAndLatencies, TelemetryConfigAsync class RedisSplitStorageTests(object): """Redis split storage test cases.""" @@ -169,6 +178,275 @@ def test_is_valid_traffic_type_with_cache(self, mocker): time.sleep(1) assert storage.is_valid_traffic_type('any') is False + @mock.patch('splitio.storage.adapters.redis.RedisPipelineAdapter.execute', return_value = [{'split1', 'split2'}]) + def test_flag_sets(self, mocker): + """Test Flag sets scenarios.""" + adapter = build({}) + storage = RedisSplitStorage(adapter, True, 1) + assert storage.flag_set_filter.flag_sets == set({}) + assert sorted(storage.get_feature_flags_by_sets(['set1', 'set2'])) == ['split1', 'split2'] + + storage.flag_set_filter = FlagSetsFilter(['set2', 'set3']) + assert storage.get_feature_flags_by_sets(['set1']) == [] + assert sorted(storage.get_feature_flags_by_sets(['set2'])) == ['split1', 'split2'] + + storage2 = RedisSplitStorage(adapter, True, 1, ['set2', 'set3']) + assert storage2.flag_set_filter.flag_sets == set({'set2', 'set3'}) + +class RedisSplitStorageAsyncTests(object): + """Redis split storage test cases.""" + + @pytest.mark.asyncio + async def test_get_split(self, mocker): + """Test retrieving a split works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '{"name": "some_split"}' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + storage = RedisSplitStorageAsync(adapter) + await storage.get('some_split') + + assert self.name == 'SPLITIO.split.some_split' + assert self.redis_ret == '{"name": "some_split"}' + + # Test that a missing split returns None and doesn't call from_raw + from_raw.reset_mock() + self.name = None + async def get2(sel, name): + self.name = name + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + + result = await storage.get('some_split') + assert result is None + assert self.name == 'SPLITIO.split.some_split' + assert not from_raw.mock_calls + + @pytest.mark.asyncio + async def test_get_split_with_cache(self, mocker): + """Test retrieving a split works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '{"name": "some_split"}' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + storage = RedisSplitStorageAsync(adapter, True, 1) + await storage.get('some_split') + assert self.name == 'SPLITIO.split.some_split' + assert self.redis_ret == '{"name": "some_split"}' + + # hit the cache: + self.name = None + await storage.get('some_split') + self.name = None + await storage.get('some_split') + self.name = None + await storage.get('some_split') + assert self.name == None + + # Test that a missing split returns None and doesn't call from_raw + from_raw.reset_mock() + self.name = None + async def get2(sel, name): + self.name = name + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + + # Still cached + result = await storage.get('some_split') + assert result is not None + assert self.name == None + await asyncio.sleep(1) # wait for expiration + result = await storage.get('some_split') + assert self.name == 'SPLITIO.split.some_split' + assert result is None + + @pytest.mark.asyncio + async def test_get_splits_with_cache(self, mocker): + """Test retrieving a list of passed splits.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter, True, 1) + + self.redis_ret = None + self.name = None + async def mget(sel, name): + self.name = name + self.redis_ret = ['{"name": "split1"}', '{"name": "split2"}', None] + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.mget', new=mget) + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + result = await storage.fetch_many(['split1', 'split2', 'split3']) + assert len(result) == 3 + + assert '{"name": "split1"}' in self.redis_ret + assert '{"name": "split2"}' in self.redis_ret + + assert result['split1'] is not None + assert result['split2'] is not None + assert 'split3' in result + + # fetch again + self.name = None + result = await storage.fetch_many(['split1', 'split2', 'split3']) + assert result['split1'] is not None + assert result['split2'] is not None + assert 'split3' in result + assert self.name == None + + # wait for expire + await asyncio.sleep(1) + self.name = None + result = await storage.fetch_many(['split1', 'split2', 'split3']) + assert self.name == ['SPLITIO.split.split1', 'SPLITIO.split.split2', 'SPLITIO.split.split3'] + + @pytest.mark.asyncio + async def test_get_changenumber(self, mocker): + """Test fetching changenumber.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '-1' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + assert await storage.get_change_number() == -1 + assert self.name == 'SPLITIO.splits.till' + + @pytest.mark.asyncio + async def test_get_all_splits(self, mocker): + """Test fetching all splits.""" + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + self.redis_ret = None + self.name = None + async def mget(sel, name): + self.name = name + self.redis_ret = ['{"name": "split1"}', '{"name": "split2"}', '{"name": "split3"}'] + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.mget', new=mget) + + self.key = None + self.keys_ret = None + async def keys(sel, key): + self.key = key + self.keys_ret = [ + 'SPLITIO.split.split1', + 'SPLITIO.split.split2', + 'SPLITIO.split.split3' + ] + return self.keys_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.keys', new=keys) + + await storage.get_all_splits() + + assert self.key == 'SPLITIO.split.*' + assert self.keys_ret == ['SPLITIO.split.split1', 'SPLITIO.split.split2', 'SPLITIO.split.split3'] + assert len(from_raw.mock_calls) == 3 + assert mocker.call({'name': 'split1'}) in from_raw.mock_calls + assert mocker.call({'name': 'split2'}) in from_raw.mock_calls + assert mocker.call({'name': 'split3'}) in from_raw.mock_calls + + @pytest.mark.asyncio + async def test_get_split_names(self, mocker): + """Test getching split names.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + self.key = None + self.keys_ret = None + async def keys(sel, key): + self.key = key + self.keys_ret = [ + 'SPLITIO.split.split1', + 'SPLITIO.split.split2', + 'SPLITIO.split.split3' + ] + return self.keys_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.keys', new=keys) + + assert await storage.get_split_names() == ['split1', 'split2', 'split3'] + + @pytest.mark.asyncio + async def test_is_valid_traffic_type(self, mocker): + """Test that traffic type validation works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + async def get(sel, name): + return '1' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + assert await storage.is_valid_traffic_type('any') is True + + async def get2(sel, name): + return '0' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + assert await storage.is_valid_traffic_type('any') is False + + async def get3(sel, name): + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get3) + assert await storage.is_valid_traffic_type('any') is False + + @pytest.mark.asyncio + async def test_is_valid_traffic_type_with_cache(self, mocker): + """Test that traffic type validation works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter, True, 1) + + async def get(sel, name): + return '1' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + assert await storage.is_valid_traffic_type('any') is True + + async def get2(sel, name): + return '0' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + assert await storage.is_valid_traffic_type('any') is True + await asyncio.sleep(1) + assert await storage.is_valid_traffic_type('any') is False + + async def get3(sel, name): + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get3) + await asyncio.sleep(1) + assert await storage.is_valid_traffic_type('any') is False + class RedisSegmentStorageTests(object): """Redis segment storage test cases.""" @@ -220,6 +498,84 @@ def test_segment_contains(self, mocker): mocker.call('SPLITIO.segment.some_segment', 'some_key') ] +class RedisSegmentStorageAsyncTests(object): + """Redis segment storage test cases.""" + + @pytest.mark.asyncio + async def test_fetch_segment(self, mocker): + """Test fetching a whole segment.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.key = None + async def smembers(key): + self.key = key + return set(["key1", "key2", "key3"]) + adapter.smembers = smembers + + self.key2 = None + async def get(key): + self.key2 = key + return '100' + adapter.get = get + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.segments.from_raw', new=from_raw) + + storage = RedisSegmentStorageAsync(adapter) + result = await storage.get('some_segment') + assert isinstance(result, Segment) + assert result.name == 'some_segment' + assert result.contains('key1') + assert result.contains('key2') + assert result.contains('key3') + assert result.change_number == 100 + assert self.key == 'SPLITIO.segment.some_segment' + assert self.key2 == 'SPLITIO.segment.some_segment.till' + + # Assert that if segment doesn't exist, None is returned + from_raw.reset_mock() + async def smembers2(key): + self.key = key + return set() + adapter.smembers = smembers2 + assert await storage.get('some_segment') is None + + @pytest.mark.asyncio + async def test_fetch_change_number(self, mocker): + """Test fetching change number.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.key = None + async def get(key): + self.key = key + return '100' + adapter.get = get + + storage = RedisSegmentStorageAsync(adapter) + result = await storage.get_change_number('some_segment') + assert result == 100 + assert self.key == 'SPLITIO.segment.some_segment.till' + + @pytest.mark.asyncio + async def test_segment_contains(self, mocker): + """Test segment contains functionality.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSegmentStorageAsync(adapter) + self.key = None + self.segment = None + async def sismember(segment, key): + self.key = key + self.segment = segment + return True + adapter.sismember = sismember + + assert await storage.segment_contains('some_segment', 'some_key') is True + assert self.segment == 'SPLITIO.segment.some_segment' + assert self.key == 'some_key' + class RedisImpressionsStorageTests(object): # pylint: disable=too-few-public-methods """Redis Impressions storage test cases.""" @@ -231,10 +587,10 @@ def test_wrap_impressions(self, mocker): storage = RedisImpressionsStorage(adapter, metadata) impressions = [ - Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) ] to_validate = [json.dumps({ @@ -251,6 +607,7 @@ def test_wrap_impressions(self, mocker): 'r': impression.label, 'c': impression.change_number, 'm': impression.time, + 'properties': None } }) for impression in impressions] @@ -263,10 +620,10 @@ def test_add_impressions(self, mocker): storage = RedisImpressionsStorage(adapter, metadata) impressions = [ - Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) ] assert storage.put(impressions) is True @@ -285,6 +642,7 @@ def test_add_impressions(self, mocker): 'r': impression.label, 'c': impression.change_number, 'm': impression.time, + 'properties': None } }) for impression in impressions] @@ -305,11 +663,111 @@ def test_add_impressions_to_pipe(self, mocker): storage = RedisImpressionsStorage(adapter, metadata) impressions = [ - Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), - Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) + ] + + to_validate = [json.dumps({ + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + }, + 'i': { # IMPRESSION PORTION + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + 'properties': None + } + }) for impression in impressions] + + storage.add_impressions_to_pipe(impressions, adapter) + assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', *to_validate)] + + def test_expire_key(self, mocker): + adapter = mocker.Mock(spec=RedisAdapter) + metadata = get_metadata({}) + storage = RedisImpressionsStorage(adapter, metadata) + + self.key = None + self.ttl = None + def expire(key, ttl): + self.key = key + self.ttl = ttl + adapter.expire = expire + + storage.expire_key(2, 2) + assert self.key == 'SPLITIO.impressions' + assert self.ttl == 3600 + + self.key = None + storage.expire_key(2, 1) + assert self.key == None + + +class RedisImpressionsStorageAsyncTests(object): # pylint: disable=too-few-public-methods + """Redis Impressions async storage test cases.""" + + def test_wrap_impressions(self, mocker): + """Test wrap impressions.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) + ] + + to_validate = [json.dumps({ + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + }, + 'i': { # IMPRESSION PORTION + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + 'properties': None + } + }) for impression in impressions] + + assert storage._wrap_impressions(impressions) == to_validate + + @pytest.mark.asyncio + async def test_add_impressions(self, mocker): + """Test that adding impressions to storage works.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) ] + self.key = None + self.imps = None + async def rpush(key, *imps): + self.key = key + self.imps = imps + + adapter.rpush = rpush + assert await storage.put(impressions) is True to_validate = [json.dumps({ 'm': { # METADATA PORTION @@ -325,12 +783,77 @@ def test_add_impressions_to_pipe(self, mocker): 'r': impression.label, 'c': impression.change_number, 'm': impression.time, + 'properties': None + } + }) for impression in impressions] + + assert self.key == 'SPLITIO.impressions' + assert self.imps == tuple(to_validate) + + # Assert that if an exception is thrown it's caught and False is returned + adapter.reset_mock() + + async def rpush2(key, *imps): + raise RedisAdapterException('something') + adapter.rpush = rpush2 + assert await storage.put(impressions) is False + + def test_add_impressions_to_pipe(self, mocker): + """Test that adding impressions to storage works.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654, None, None), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654, None, None) + ] + + to_validate = [json.dumps({ + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + }, + 'i': { # IMPRESSION PORTION + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + 'properties': None } }) for impression in impressions] storage.add_impressions_to_pipe(impressions, adapter) assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', *to_validate)] + @pytest.mark.asyncio + async def test_expire_key(self, mocker): + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + self.key = None + self.ttl = None + async def expire(key, ttl): + self.key = key + self.ttl = ttl + adapter.expire = expire + + await storage.expire_key(2, 2) + assert self.key == 'SPLITIO.impressions' + assert self.ttl == 3600 + + self.key = None + await storage.expire_key(2, 1) + assert self.key == None + + class RedisEventsStorageTests(object): # pylint: disable=too-few-public-methods """Redis Impression storage test cases.""" @@ -380,3 +903,546 @@ def _raise_exc(*_): raise RedisAdapterException('something') adapter.rpush.side_effect = _raise_exc assert storage.put(events) is False + + def test_expire_keys(self, mocker): + adapter = mocker.Mock(spec=RedisAdapter) + metadata = get_metadata({}) + storage = RedisEventsStorage(adapter, metadata) + + self.key = None + self.ttl = None + def expire(key, ttl): + self.key = key + self.ttl = ttl + adapter.expire = expire + + storage.expire_keys(2, 2) + assert self.key == 'SPLITIO.events' + assert self.ttl == 3600 + + self.key = None + storage.expire_keys(2, 1) + assert self.key == None + +class RedisEventsStorageAsyncTests(object): # pylint: disable=too-few-public-methods + """Redis Impression async storage test cases.""" + + @pytest.mark.asyncio + async def test_add_events(self, mocker): + """Test that adding impressions to storage works.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + + storage = RedisEventsStorageAsync(adapter, metadata) + + events = [ + EventWrapper(event=Event('key1', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key2', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key3', 'user', 'purchase', 10, 123456, None), size=32768), + EventWrapper(event=Event('key4', 'user', 'purchase', 10, 123456, None), size=32768), + ] + self.key = None + self.events = None + async def rpush(key, *events): + self.key = key + self.events = events + adapter.rpush = rpush + + assert await storage.put(events) is True + + list_of_raw_events = [json.dumps({ + 'e': { # EVENT PORTION + 'key': e.event.key, + 'trafficTypeName': e.event.traffic_type_name, + 'eventTypeId': e.event.event_type_id, + 'value': e.event.value, + 'timestamp': e.event.timestamp, + 'properties': e.event.properties, + }, + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + } + }) for e in events] + + assert self.events == tuple(list_of_raw_events) + assert self.key == 'SPLITIO.events' + assert storage._wrap_events(events) == list_of_raw_events + + # Assert that if an exception is thrown it's caught and False is returned + adapter.reset_mock() + + async def rpush2(key, *events): + raise RedisAdapterException('something') + adapter.rpush = rpush2 + assert await storage.put(events) is False + + + @pytest.mark.asyncio + async def test_expire_keys(self, mocker): + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisEventsStorageAsync(adapter, metadata) + + self.key = None + self.ttl = None + async def expire(key, ttl): + self.key = key + self.ttl = ttl + adapter.expire = expire + + await storage.expire_keys(2, 2) + assert self.key == 'SPLITIO.events' + assert self.ttl == 3600 + + self.key = None + await storage.expire_keys(2, 1) + assert self.key == None + + +class RedisTelemetryStorageTests(object): + """Redis Telemetry storage test cases.""" + + def test_init(self, mocker): + redis_telemetry = RedisTelemetryStorage(mocker.Mock(), mocker.Mock()) + assert(redis_telemetry._redis_client is not None) + assert(redis_telemetry._sdk_metadata is not None) + assert(isinstance(redis_telemetry._tel_config, TelemetryConfig)) + assert(redis_telemetry._make_pipe is not None) + + @mock.patch('splitio.models.telemetry.TelemetryConfig.record_config') + def test_record_config(self, mocker): + redis_telemetry = RedisTelemetryStorage(mocker.Mock(), mocker.Mock()) + redis_telemetry.record_config(mocker.Mock(), mocker.Mock(), 0, 0) + assert(mocker.called) + + @mock.patch('splitio.storage.adapters.redis.RedisAdapter.hset') + def test_push_config_stats(self, mocker): + adapter = build({}) + redis_telemetry = RedisTelemetryStorage(adapter, mocker.Mock()) + redis_telemetry.push_config_stats() + assert(mocker.called) + + def test_format_config_stats(self, mocker): + redis_telemetry = RedisTelemetryStorage(mocker.Mock(), mocker.Mock()) + json_value = redis_telemetry._format_config_stats({'aF': 0, 'rF': 0, 'sT': None, 'oM': None}, []) + stats = redis_telemetry._tel_config.get_stats() + assert(json_value == json.dumps({ + 'aF': stats['aF'], + 'rF': stats['rF'], + 'sT': stats['sT'], + 'oM': stats['oM'], + 't': redis_telemetry.pop_config_tags(), + })) + + def test_record_active_and_redundant_factories(self, mocker): + redis_telemetry = RedisTelemetryStorage(mocker.Mock(), mocker.Mock()) + active_factory_count = 1 + redundant_factory_count = 2 + redis_telemetry.record_active_and_redundant_factories(1, 2) + assert (redis_telemetry._tel_config._active_factory_count == active_factory_count) + assert (redis_telemetry._tel_config._redundant_factory_count == redundant_factory_count) + + def test_add_latency_to_pipe(self, mocker): + adapter = build({}) + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = RedisTelemetryStorage(adapter, metadata) + pipe = adapter._decorated.pipeline() + + def _mocked_hincrby(*args, **kwargs): + assert(args[1] == RedisTelemetryStorage._TELEMETRY_LATENCIES_KEY) + assert(args[2][-11:] == 'treatment/0') + assert(args[3] == 1) + # should increment bucket 0 + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby): + redis_telemetry.add_latency_to_pipe(MethodExceptionsAndLatencies.TREATMENT, 0, pipe) + + def _mocked_hincrby2(*args, **kwargs): + assert(args[1] == RedisTelemetryStorage._TELEMETRY_LATENCIES_KEY) + assert(args[2][-11:] == 'treatment/3') + assert(args[3] == 1) + # should increment bucket 3 + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby2): + redis_telemetry.add_latency_to_pipe(MethodExceptionsAndLatencies.TREATMENT, 3, pipe) + + def test_record_exception(self, mocker): + def _mocked_hincrby(*args, **kwargs): + assert(args[1] == RedisTelemetryStorage._TELEMETRY_EXCEPTIONS_KEY) + assert(args[2] == 'python-1.1.1/hostname/ip/treatment') + assert(args[3] == 1) + + adapter = build({}) + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = RedisTelemetryStorage(adapter, metadata) + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby): + with mock.patch('redis.client.Pipeline.execute') as mock_method: + mock_method.return_value = [1] + redis_telemetry.record_exception(MethodExceptionsAndLatencies.TREATMENT) + + def test_expire_latency_keys(self, mocker): + redis_telemetry = RedisTelemetryStorage(mocker.Mock(), mocker.Mock()) + def _mocked_method(*args, **kwargs): + assert(args[1] == RedisTelemetryStorage._TELEMETRY_LATENCIES_KEY) + assert(args[2] == RedisTelemetryStorage._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[3] == 1) + assert(args[4] == 2) + + with mock.patch('splitio.storage.redis.RedisTelemetryStorage.expire_keys', _mocked_method): + redis_telemetry.expire_latency_keys(1, 2) + + @mock.patch('redis.client.Redis.expire') + def test_expire_keys(self, mocker): + adapter = build({}) + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = RedisTelemetryStorage(adapter, metadata) + redis_telemetry.expire_keys('key', 12, 1, 2) + assert(not mocker.called) + redis_telemetry.expire_keys('key', 12, 2, 2) + assert(mocker.called) + + +class RedisTelemetryStorageAsyncTests(object): + """Redis Telemetry storage test cases.""" + + @pytest.mark.asyncio + async def test_init(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + assert(redis_telemetry._redis_client is not None) + assert(redis_telemetry._sdk_metadata is not None) + assert(isinstance(redis_telemetry._tel_config, TelemetryConfigAsync)) + assert(redis_telemetry._make_pipe is not None) + + @pytest.mark.asyncio + async def test_record_config(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + self.called = False + async def record_config(*args): + self.called = True + redis_telemetry._tel_config.record_config = record_config + + await redis_telemetry.record_config(mocker.Mock(), mocker.Mock(), 0, 0) + assert(self.called) + + @pytest.mark.asyncio + async def test_push_config_stats(self, mocker): + adapter = await aioredis.from_url("redis://localhost") + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, SdkMetadata('python-1.1.1', 'hostname', 'ip')) + self.key = None + self.hash = None + async def hset(key, hash, val): + self.key = key + self.hash = hash + + adapter.hset = hset + def format_config_stats(stats, tags): + return "" + redis_telemetry._format_config_stats = format_config_stats + await redis_telemetry.push_config_stats() + assert self.key == 'SPLITIO.telemetry.init' + assert self.hash == 'python-1.1.1/hostname/ip' + + @pytest.mark.asyncio + async def test_format_config_stats(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + json_value = redis_telemetry._format_config_stats({'aF': 0, 'rF': 0, 'sT': None, 'oM': None}, []) + stats = await redis_telemetry._tel_config.get_stats() + assert(json_value == json.dumps({ + 'aF': stats['aF'], + 'rF': stats['rF'], + 'sT': stats['sT'], + 'oM': stats['oM'], + 't': await redis_telemetry.pop_config_tags() + })) + + @pytest.mark.asyncio + async def test_record_active_and_redundant_factories(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + active_factory_count = 1 + redundant_factory_count = 2 + await redis_telemetry.record_active_and_redundant_factories(1, 2) + assert (redis_telemetry._tel_config._active_factory_count == active_factory_count) + assert (redis_telemetry._tel_config._redundant_factory_count == redundant_factory_count) + + @pytest.mark.asyncio + async def test_add_latency_to_pipe(self, mocker): + adapter = build({}) + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, metadata) + pipe = adapter._decorated.pipeline() + + def _mocked_hincrby(*args, **kwargs): + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_LATENCIES_KEY) + assert(args[2][-11:] == 'treatment/0') + assert(args[3] == 1) + # should increment bucket 0 + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby): + redis_telemetry.add_latency_to_pipe(MethodExceptionsAndLatencies.TREATMENT, 0, pipe) + + def _mocked_hincrby2(*args, **kwargs): + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_LATENCIES_KEY) + assert(args[2][-11:] == 'treatment/3') + assert(args[3] == 1) + # should increment bucket 3 + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby2): + redis_telemetry.add_latency_to_pipe(MethodExceptionsAndLatencies.TREATMENT, 3, pipe) + + @pytest.mark.asyncio + async def test_record_exception(self, mocker): + self.called = False + def _mocked_hincrby(*args, **kwargs): + self.called = True + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_EXCEPTIONS_KEY) + assert(args[2] == 'python-1.1.1/hostname/ip/treatment') + assert(args[3] == 1) + + self.called2 = False + async def _mocked_execute(*args): + self.called2 = True + return [1] + + adapter = await aioredis.from_url("redis://localhost") + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, metadata) + with mock.patch('redis.asyncio.client.Pipeline.hincrby', _mocked_hincrby): + with mock.patch('redis.asyncio.client.Pipeline.execute', _mocked_execute): + await redis_telemetry.record_exception(MethodExceptionsAndLatencies.TREATMENT) + assert self.called + assert self.called2 + + @pytest.mark.asyncio + async def test_expire_latency_keys(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + def _mocked_method(*args, **kwargs): + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_LATENCIES_KEY) + assert(args[2] == RedisTelemetryStorageAsync._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[3] == 1) + assert(args[4] == 2) + + with mock.patch('splitio.storage.redis.RedisTelemetryStorage.expire_keys', _mocked_method): + await redis_telemetry.expire_latency_keys(1, 2) + + @pytest.mark.asyncio + async def test_expire_keys(self, mocker): + adapter = await aioredis.from_url("redis://localhost") + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, metadata) + self.called = False + async def expire(*args): + self.called = True + adapter.expire = expire + + await redis_telemetry.expire_keys('key', 12, 1, 2) + assert(not self.called) + + await redis_telemetry.expire_keys('key', 12, 2, 2) + assert(self.called) + +class RedisRuleBasedSegmentStorageTests(object): + """Redis rule based segment storage test cases.""" + + def test_get_segment(self, mocker): + """Test retrieving a rule based segment works.""" + adapter = mocker.Mock(spec=RedisAdapter) + adapter.get.return_value = '{"name": "some_segment"}' + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.rule_based_segments.from_raw', new=from_raw) + + storage = RedisRuleBasedSegmentsStorage(adapter) + storage.get('some_segment') + + assert adapter.get.mock_calls == [mocker.call('SPLITIO.rbsegment.some_segment')] + assert from_raw.mock_calls == [mocker.call({"name": "some_segment"})] + + # Test that a missing split returns None and doesn't call from_raw + adapter.reset_mock() + from_raw.reset_mock() + adapter.get.return_value = None + result = storage.get('some_segment') + assert result is None + assert adapter.get.mock_calls == [mocker.call('SPLITIO.rbsegment.some_segment')] + assert not from_raw.mock_calls + + def test_get_changenumber(self, mocker): + """Test fetching changenumber.""" + adapter = mocker.Mock(spec=RedisAdapter) + storage = RedisRuleBasedSegmentsStorage(adapter) + adapter.get.return_value = '-1' + assert storage.get_change_number() == -1 + assert adapter.get.mock_calls == [mocker.call('SPLITIO.rbsegments.till')] + + def test_get_segment_names(self, mocker): + """Test getching rule based segment names.""" + adapter = mocker.Mock(spec=RedisAdapter) + storage = RedisRuleBasedSegmentsStorage(adapter) + adapter.keys.return_value = [ + 'SPLITIO.rbsegment.segment1', + 'SPLITIO.rbsegment.segment2', + 'SPLITIO.rbsegment.segment3' + ] + assert storage.get_segment_names() == ['segment1', 'segment2', 'segment3'] + + def test_contains(self, mocker): + """Test storage containing rule based segment names.""" + adapter = mocker.Mock(spec=RedisAdapter) + storage = RedisRuleBasedSegmentsStorage(adapter) + adapter.keys.return_value = [ + 'SPLITIO.rbsegment.segment1', + 'SPLITIO.rbsegment.segment2', + 'SPLITIO.rbsegment.segment3' + ] + assert storage.contains(['segment1', 'segment3']) + assert not storage.contains(['segment1', 'segment4']) + assert storage.contains(['segment1']) + assert not storage.contains(['segment4', 'segment5']) + + def test_fetch_many(self, mocker): + """Test retrieving a list of passed splits.""" + adapter = mocker.Mock(spec=RedisAdapter) + storage = RedisRuleBasedSegmentsStorage(adapter) + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.rule_based_segments.from_raw', new=from_raw) + + adapter.mget.return_value = ['{"name": "rbs1"}', '{"name": "rbs2"}', None] + + result = storage.fetch_many(['rbs1', 'rbs2', 'rbs3']) + assert len(result) == 3 + + assert mocker.call({'name': 'rbs1'}) in from_raw.mock_calls + assert mocker.call({'name': 'rbs2'}) in from_raw.mock_calls + + assert result['rbs1'] is not None + assert result['rbs2'] is not None + assert 'rbs3' in result + + # should not raise exception + result = storage.fetch_many([]) + assert len(result) == 0 + +class RedisRuleBasedSegmentStorageAsyncTests(object): + """Redis rule based segment storage test cases.""" + + @pytest.mark.asyncio + async def test_get_segment(self, mocker): + """Test retrieving a rule based segment works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '{"changeNumber": "12", "name": "some_segment", "status": "ACTIVE","trafficTypeName": "user","excluded":{"keys":[],"segments":[]},"conditions": []}' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + storage = RedisRuleBasedSegmentsStorageAsync(adapter) + await storage.get('some_segment') + + assert self.name == 'SPLITIO.rbsegment.some_segment' + assert self.redis_ret == '{"changeNumber": "12", "name": "some_segment", "status": "ACTIVE","trafficTypeName": "user","excluded":{"keys":[],"segments":[]},"conditions": []}' + + # Test that a missing split returns None and doesn't call from_raw + + self.name = None + async def get2(sel, name): + self.name = name + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + + result = await storage.get('some_segment') + assert result is None + assert self.name == 'SPLITIO.rbsegment.some_segment' + + # Test that a missing split returns None and doesn't call from_raw + result = await storage.get('some_segment2') + assert result is None + + @pytest.mark.asyncio + async def test_get_changenumber(self, mocker): + """Test fetching changenumber.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisRuleBasedSegmentsStorageAsync(adapter) + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '-1' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + assert await storage.get_change_number() == -1 + assert self.name == 'SPLITIO.rbsegments.till' + + @pytest.mark.asyncio + async def test_get_segment_names(self, mocker): + """Test getching rule based segment names.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisRuleBasedSegmentsStorageAsync(adapter) + + self.key = None + self.keys_ret = None + async def keys(sel, key): + self.key = key + self.keys_ret = [ + 'SPLITIO.rbsegment.segment1', + 'SPLITIO.rbsegment.segment2', + 'SPLITIO.rbsegment.segment3' + ] + return self.keys_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.keys', new=keys) + + assert await storage.get_segment_names() == ['segment1', 'segment2', 'segment3'] + + @pytest.mark.asyncio + async def test_contains(self, mocker): + """Test storage containing rule based segment names.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisRuleBasedSegmentsStorageAsync(adapter) + + self.key = None + self.keys_ret = None + async def keys(sel, key): + self.key = key + self.keys_ret = [ + 'SPLITIO.rbsegment.segment1', + 'SPLITIO.rbsegment.segment2', + 'SPLITIO.rbsegment.segment3' + ] + return self.keys_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.keys', new=keys) + + assert await storage.contains(['segment1', 'segment3']) + assert not await storage.contains(['segment1', 'segment4']) + assert await storage.contains(['segment1']) + assert not await storage.contains(['segment4', 'segment5']) + + @pytest.mark.asyncio + async def test_fetch_many(self, mocker): + """Test retrieving a list of passed splits.""" + adapter = mocker.Mock(spec=RedisAdapter) + storage = RedisRuleBasedSegmentsStorageAsync(adapter) + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.rule_based_segments.from_raw', new=from_raw) + async def mget(*_): + return ['{"name": "rbs1"}', '{"name": "rbs2"}', None] + adapter.mget = mget + + result = await storage.fetch_many(['rbs1', 'rbs2', 'rbs3']) + assert len(result) == 3 + + assert mocker.call({'name': 'rbs1'}) in from_raw.mock_calls + assert mocker.call({'name': 'rbs2'}) in from_raw.mock_calls + + assert result['rbs1'] is not None + assert result['rbs2'] is not None + assert 'rbs3' in result + + # should not raise exception + result = await storage.fetch_many([]) + assert len(result) == 0 + diff --git a/tests/sync/test_events_synchronizer.py b/tests/sync/test_events_synchronizer.py index 862f695f..7eb52dc4 100644 --- a/tests/sync/test_events_synchronizer.py +++ b/tests/sync/test_events_synchronizer.py @@ -8,7 +8,7 @@ from splitio.api import APIException from splitio.storage import EventStorage from splitio.models.events import Event -from splitio.sync.event import EventSynchronizer +from splitio.sync.event import EventSynchronizer, EventSynchronizerAsync class EventsSynchronizerTests(object): @@ -57,7 +57,7 @@ def test_synchronize_impressions(self, mocker): def run(x): run._called += 1 - return HttpResponse(200, '') + return HttpResponse(200, '', {}) api.flush_events.side_effect = run run._called = 0 @@ -66,3 +66,66 @@ def run(x): event_synchronizer.synchronize_events() assert run._called == 1 assert event_synchronizer._failed.qsize() == 0 + + +class EventsSynchronizerAsyncTests(object): + """Events synchronizer async test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_events_error(self, mocker): + storage = mocker.Mock(spec=EventStorage) + async def pop_many(*args): + return [ + Event('key1', 'user', 'purchase', 5.3, 123456, None), + Event('key2', 'user', 'purchase', 5.3, 123456, None), + ] + storage.pop_many = pop_many + + api = mocker.Mock() + async def run(x): + raise APIException("something broke") + + api.flush_events = run + event_synchronizer = EventSynchronizerAsync(api, storage, 5) + await event_synchronizer.synchronize_events() + assert event_synchronizer._failed.qsize() == 2 + + @pytest.mark.asyncio + async def test_synchronize_events_empty(self, mocker): + storage = mocker.Mock(spec=EventStorage) + async def pop_many(*args): + return [] + storage.pop_many = pop_many + + api = mocker.Mock() + async def run(x): + run._called += 1 + + run._called = 0 + api.flush_events = run + event_synchronizer = EventSynchronizerAsync(api, storage, 5) + await event_synchronizer.synchronize_events() + assert run._called == 0 + + @pytest.mark.asyncio + async def test_synchronize_impressions(self, mocker): + storage = mocker.Mock(spec=EventStorage) + async def pop_many(*args): + return [ + Event('key1', 'user', 'purchase', 5.3, 123456, None), + Event('key2', 'user', 'purchase', 5.3, 123456, None), + ] + storage.pop_many = pop_many + + api = mocker.Mock() + async def run(x): + run._called += 1 + return HttpResponse(200, '', {}) + + api.flush_events.side_effect = run + run._called = 0 + + event_synchronizer = EventSynchronizerAsync(api, storage, 5) + await event_synchronizer.synchronize_events() + assert run._called == 1 + assert event_synchronizer._failed.qsize() == 0 diff --git a/tests/sync/test_impressions_count_synchronizer.py b/tests/sync/test_impressions_count_synchronizer.py index 4f9f1ca4..3db1753e 100644 --- a/tests/sync/test_impressions_count_synchronizer.py +++ b/tests/sync/test_impressions_count_synchronizer.py @@ -6,9 +6,10 @@ from splitio.api.client import HttpResponse from splitio.api import APIException -from splitio.engine.impressions import Manager as ImpressionsManager -from splitio.engine.impressions import Counter -from splitio.sync.impression import ImpressionsCountSynchronizer +from splitio.engine.impressions.impressions import Manager as ImpressionsManager +from splitio.engine.impressions.manager import Counter +from splitio.engine.impressions.strategies import StrategyOptimizedMode +from splitio.sync.impression import ImpressionsCountSynchronizer, ImpressionsCountSynchronizerAsync from splitio.api.impressions import ImpressionsAPI @@ -16,7 +17,7 @@ class ImpressionsCountSynchronizerTests(object): """ImpressionsCount synchronizer test cases.""" def test_synchronize_impressions_counts(self, mocker): - manager = mocker.Mock(spec=ImpressionsManager) + counter = mocker.Mock(spec=Counter) counters = [ Counter.CountPerFeature('f1', 123, 2), @@ -25,13 +26,50 @@ def test_synchronize_impressions_counts(self, mocker): Counter.CountPerFeature('f2', 456, 222) ] - manager.get_counts.return_value = counters + counter.pop_all.return_value = counters api = mocker.Mock(spec=ImpressionsAPI) - api.flush_counters.return_value = HttpResponse(200, '') - impression_count_synchronizer = ImpressionsCountSynchronizer(api, manager) + api.flush_counters.return_value = HttpResponse(200, '', {}) + impression_count_synchronizer = ImpressionsCountSynchronizer(api, counter) impression_count_synchronizer.synchronize_counters() - assert manager.get_counts.mock_calls[0] == mocker.call() + assert counter.pop_all.mock_calls[0] == mocker.call() assert api.flush_counters.mock_calls[0] == mocker.call(counters) assert len(api.flush_counters.mock_calls) == 1 + + +class ImpressionsCountSynchronizerAsyncTests(object): + """ImpressionsCount synchronizer test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_impressions_counts(self, mocker): + counter = mocker.Mock(spec=Counter) + + self.called = 0 + def pop_all(): + self.called += 1 + return [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + Counter.CountPerFeature('f1', 456, 111), + Counter.CountPerFeature('f2', 456, 222) + ] + counter.pop_all = pop_all + + self.counters = None + async def flush_counters(counters): + self.counters = counters + return HttpResponse(200, '', {}) + api = mocker.Mock(spec=ImpressionsAPI) + api.flush_counters = flush_counters + + impression_count_synchronizer = ImpressionsCountSynchronizerAsync(api, counter) + await impression_count_synchronizer.synchronize_counters() + + assert self.counters == [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + Counter.CountPerFeature('f1', 456, 111), + Counter.CountPerFeature('f2', 456, 222) + ] + assert self.called == 1 diff --git a/tests/sync/test_impressions_synchronizer.py b/tests/sync/test_impressions_synchronizer.py index 9d1a3848..00b65833 100644 --- a/tests/sync/test_impressions_synchronizer.py +++ b/tests/sync/test_impressions_synchronizer.py @@ -8,7 +8,7 @@ from splitio.api import APIException from splitio.storage import ImpressionStorage from splitio.models.impressions import Impression -from splitio.sync.impression import ImpressionSynchronizer +from splitio.sync.impression import ImpressionSynchronizer, ImpressionSynchronizerAsync class ImpressionsSynchronizerTests(object): @@ -17,8 +17,8 @@ class ImpressionsSynchronizerTests(object): def test_synchronize_impressions_error(self, mocker): storage = mocker.Mock(spec=ImpressionStorage) storage.pop_many.return_value = [ - Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654), - Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654), + Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), ] api = mocker.Mock() @@ -49,15 +49,15 @@ def run(x): def test_synchronize_impressions(self, mocker): storage = mocker.Mock(spec=ImpressionStorage) storage.pop_many.return_value = [ - Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654), - Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654), + Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), ] api = mocker.Mock() def run(x): run._called += 1 - return HttpResponse(200, '') + return HttpResponse(200, '', {}) api.flush_impressions.side_effect = run run._called = 0 @@ -66,3 +66,68 @@ def run(x): impression_synchronizer.synchronize_impressions() assert run._called == 1 assert impression_synchronizer._failed.qsize() == 0 + + +class ImpressionsSynchronizerAsyncTests(object): + """Impressions synchronizer test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_impressions_error(self, mocker): + storage = mocker.Mock(spec=ImpressionStorage) + async def pop_many(*args): + return [ + Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + ] + storage.pop_many = pop_many + api = mocker.Mock() + + async def run(x): + raise APIException("something broke") + api.flush_impressions = run + + impression_synchronizer = ImpressionSynchronizerAsync(api, storage, 5) + await impression_synchronizer.synchronize_impressions() + assert impression_synchronizer._failed.qsize() == 2 + + @pytest.mark.asyncio + async def test_synchronize_impressions_empty(self, mocker): + storage = mocker.Mock(spec=ImpressionStorage) + async def pop_many(*args): + return [] + storage.pop_many = pop_many + + api = mocker.Mock() + + async def run(x): + run._called += 1 + + run._called = 0 + api.flush_impressions = run + impression_synchronizer = ImpressionSynchronizerAsync(api, storage, 5) + await impression_synchronizer.synchronize_impressions() + assert run._called == 0 + + @pytest.mark.asyncio + async def test_synchronize_impressions(self, mocker): + storage = mocker.Mock(spec=ImpressionStorage) + async def pop_many(*args): + return [ + Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + ] + storage.pop_many = pop_many + + api = mocker.Mock() + + async def run(x): + run._called += 1 + return HttpResponse(200, '', {}) + + api.flush_impressions = run + run._called = 0 + + impression_synchronizer = ImpressionSynchronizerAsync(api, storage, 5) + await impression_synchronizer.synchronize_impressions() + assert run._called == 1 + assert impression_synchronizer._failed.qsize() == 0 diff --git a/tests/sync/test_manager.py b/tests/sync/test_manager.py index 27c026c1..47ac3f01 100644 --- a/tests/sync/test_manager.py +++ b/tests/sync/test_manager.py @@ -1,36 +1,44 @@ """Manager tests.""" -import pytest import threading +import unittest.mock as mock +import time +import pytest -from splitio.tasks.split_sync import SplitSynchronizationTask +from splitio.optional.loaders import asyncio +from splitio.api.auth import AuthAPI +from splitio.api import auth, client, APIException +from splitio.client.util import get_metadata +from splitio.client.config import DEFAULT_CONFIG +from splitio.tasks.split_sync import SplitSynchronizationTask, SplitSynchronizationTaskAsync from splitio.tasks.segment_sync import SegmentSynchronizationTask from splitio.tasks.impressions_sync import ImpressionsSyncTask, ImpressionsCountSyncTask from splitio.tasks.events_sync import EventsSyncTask - -from splitio.sync.split import SplitSynchronizer +from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync +from splitio.models.telemetry import SSESyncMode, StreamingEventTypes +from splitio.push.manager import Status +from splitio.sync.split import SplitSynchronizer, SplitSynchronizerAsync from splitio.sync.segment import SegmentSynchronizer from splitio.sync.impression import ImpressionSynchronizer, ImpressionsCountSynchronizer from splitio.sync.event import EventSynchronizer -from splitio.sync.synchronizer import Synchronizer, SplitTasks, SplitSynchronizers -from splitio.sync.manager import Manager - -from splitio.storage import SplitStorage - +from splitio.sync.synchronizer import Synchronizer, SynchronizerAsync, SplitTasks, SplitSynchronizers, RedisSynchronizer, RedisSynchronizerAsync +from splitio.sync.manager import Manager, ManagerAsync, RedisManager, RedisManagerAsync +from splitio.storage import SplitStorage, RuleBasedSegmentsStorage from splitio.api import APIException - from splitio.client.util import SdkMetadata -class ManagerTests(object): +class SyncManagerTests(object): """Synchronizer Manager tests.""" def test_error(self, mocker): split_task = mocker.Mock(spec=SplitSynchronizationTask) split_tasks = SplitTasks(split_task, mocker.Mock(), mocker.Mock(), mocker.Mock(), - mocker.Mock()) + mocker.Mock(), mocker.Mock()) storage = mocker.Mock(spec=SplitStorage) + rb_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) api = mocker.Mock() def run(x): @@ -39,23 +47,194 @@ def run(x): api.fetch_splits.side_effect = run storage.get_change_number.return_value = -1 - split_sync = SplitSynchronizer(api, storage) + split_sync = SplitSynchronizer(api, storage, rb_storage) synchronizers = SplitSynchronizers(split_sync, mocker.Mock(), mocker.Mock(), - mocker.Mock(), mocker.Mock()) + mocker.Mock(), mocker.Mock(), mocker.Mock()) synchronizer = Synchronizer(synchronizers, split_tasks) - manager = Manager(threading.Event(), synchronizer, mocker.Mock(), False, SdkMetadata('1.0', 'some', '1.2.3.4')) + manager = Manager(threading.Event(), synchronizer, mocker.Mock(), False, SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) - manager.start() # should not throw! + manager._SYNC_ALL_ATTEMPTS = 1 + manager.start(2) # should not throw! def test_start_streaming_false(self, mocker): splits_ready_event = threading.Event() synchronizer = mocker.Mock(spec=Synchronizer) - manager = Manager(splits_ready_event, synchronizer, mocker.Mock(), False, SdkMetadata('1.0', 'some', '1.2.3.4')) - manager.start() - + manager = Manager(splits_ready_event, synchronizer, mocker.Mock(), False, SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + try: + manager.start() + except: + pass splits_ready_event.wait(2) assert splits_ready_event.is_set() assert len(synchronizer.sync_all.mock_calls) == 1 assert len(synchronizer.start_periodic_fetching.mock_calls) == 1 assert len(synchronizer.start_periodic_data_recording.mock_calls) == 1 + + def test_telemetry(self, mocker): + splits_ready_event = threading.Event() + synchronizer = mocker.Mock(spec=Synchronizer) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_producer = TelemetryStorageProducer(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = Manager(splits_ready_event, synchronizer, mocker.Mock(), True, SdkMetadata('1.0', 'some', '1.2.3.4'), telemetry_runtime_producer) + try: + manager.start() + except: + pass + splits_ready_event.wait(2) + + manager._queue.put(Status.PUSH_SUBSYSTEM_UP) + manager._queue.put(Status.PUSH_NONRETRYABLE_ERROR) + time.sleep(1) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._type == StreamingEventTypes.SYNC_MODE_UPDATE.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._data == SSESyncMode.STREAMING.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SYNC_MODE_UPDATE.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSESyncMode.POLLING.value) + + +class SyncManagerAsyncTests(object): + """Synchronizer Manager tests.""" + + @pytest.mark.asyncio + async def test_error(self, mocker): + split_task = mocker.Mock(spec=SplitSynchronizationTask) + split_tasks = SplitTasks(split_task, mocker.Mock(), mocker.Mock(), mocker.Mock(), + mocker.Mock(), mocker.Mock()) + + storage = mocker.Mock(spec=SplitStorage) + rb_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + api = mocker.Mock() + + async def run(x): + raise APIException("something broke") + api.fetch_splits = run + + async def get_change_number(): + return -1 + storage.get_change_number = get_change_number + + split_sync = SplitSynchronizerAsync(api, storage, rb_storage) + synchronizers = SplitSynchronizers(split_sync, mocker.Mock(), mocker.Mock(), + mocker.Mock(), mocker.Mock(), mocker.Mock()) + + synchronizer = SynchronizerAsync(synchronizers, split_tasks) + manager = ManagerAsync(synchronizer, mocker.Mock(), False, SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + + manager._SYNC_ALL_ATTEMPTS = 1 + await manager.start(2) # should not throw! + + @pytest.mark.asyncio + async def test_start_streaming_false(self, mocker): + synchronizer = mocker.Mock(spec=SynchronizerAsync) + self.sync_all_called = 0 + async def sync_all(retry): + self.sync_all_called += 1 + synchronizer.sync_all = sync_all + + self.fetching_called = 0 + def start_periodic_fetching(): + self.fetching_called += 1 + synchronizer.start_periodic_fetching = start_periodic_fetching + + self.rcording_called = 0 + def start_periodic_data_recording(): + self.rcording_called += 1 + synchronizer.start_periodic_data_recording = start_periodic_data_recording + + manager = ManagerAsync(synchronizer, mocker.Mock(), False, SdkMetadata('1.0', 'some', '1.2.3.4'), mocker.Mock()) + try: + await manager.start() + except: + pass + assert self.sync_all_called == 1 + assert self.fetching_called == 1 + assert self.rcording_called == 1 + + @pytest.mark.asyncio + async def test_telemetry(self, mocker): + synchronizer = mocker.Mock(spec=SynchronizerAsync) + async def sync_all(retry=1): + pass + synchronizer.sync_all = sync_all + + async def stop_periodic_fetching(): + pass + synchronizer.stop_periodic_fetching = stop_periodic_fetching + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) + telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = ManagerAsync(synchronizer, mocker.Mock(), True, SdkMetadata('1.0', 'some', '1.2.3.4'), telemetry_runtime_producer) + try: + await manager.start() + except: + pass + + await manager._queue.put(Status.PUSH_SUBSYSTEM_UP) + await manager._queue.put(Status.PUSH_NONRETRYABLE_ERROR) + await asyncio.sleep(1) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._type == StreamingEventTypes.SYNC_MODE_UPDATE.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-2]._data == SSESyncMode.STREAMING.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._type == StreamingEventTypes.SYNC_MODE_UPDATE.value) + assert(telemetry_storage._streaming_events._streaming_events[len(telemetry_storage._streaming_events._streaming_events)-1]._data == SSESyncMode.POLLING.value) + + +class RedisSyncManagerTests(object): + """Synchronizer Redis Manager tests.""" + + synchronizers = SplitSynchronizers(None, None, None, None, None, None, None, None) + tasks = SplitTasks(None, None, None, None, None, None, None, None) + synchronizer = RedisSynchronizer(synchronizers, tasks) + manager = RedisManager(synchronizer) + + @mock.patch('splitio.sync.synchronizer.RedisSynchronizer.start_periodic_data_recording') + def test_recreate_and_start(self, mocker): + + assert(isinstance(self.manager._synchronizer, RedisSynchronizer)) + + self.manager.recreate() + assert(not mocker.called) + + self.manager.start() + assert(mocker.called) + + @mock.patch('splitio.sync.synchronizer.RedisSynchronizer.shutdown') + def test_recreate_and_stop(self, mocker): + + self.manager.recreate() + assert(not mocker.called) + + self.manager.stop(True) + assert(mocker.called) + + +class RedisSyncManagerAsyncTests(object): + """Synchronizer Redis Manager async tests.""" + + synchronizers = SplitSynchronizers(None, None, None, None, None, None, None, None) + tasks = SplitTasks(None, None, None, None, None, None, None, None) + synchronizer = RedisSynchronizerAsync(synchronizers, tasks) + manager = RedisManagerAsync(synchronizer) + + @mock.patch('splitio.sync.synchronizer.RedisSynchronizerAsync.start_periodic_data_recording') + def test_recreate_and_start(self, mocker): + assert(isinstance(self.manager._synchronizer, RedisSynchronizerAsync)) + + self.manager.recreate() + assert(not mocker.called) + + self.manager.start() + assert(mocker.called) + + @pytest.mark.asyncio + async def test_recreate_and_stop(self, mocker): + self.called = False + async def shutdown(block): + self.called = True + self.manager._synchronizer.shutdown = shutdown + self.manager.recreate() + assert(not self.called) + + await self.manager.stop(True) + assert(self.called) diff --git a/tests/sync/test_segments_synchronizer.py b/tests/sync/test_segments_synchronizer.py index 1b4c9539..5b405ef8 100644 --- a/tests/sync/test_segments_synchronizer.py +++ b/tests/sync/test_segments_synchronizer.py @@ -1,11 +1,18 @@ """Split Worker tests.""" +import os + from splitio.util.backoff import Backoff from splitio.api import APIException from splitio.api.commons import FetchOptions -from splitio.storage import SplitStorage, SegmentStorage +from splitio.storage import SplitStorage, SegmentStorage, RuleBasedSegmentsStorage +from splitio.storage.inmemmory import InMemorySegmentStorage, InMemorySegmentStorageAsync, InMemorySplitStorage, InMemorySplitStorageAsync +from splitio.sync.segment import SegmentSynchronizer, SegmentSynchronizerAsync, LocalSegmentSynchronizer, LocalSegmentSynchronizerAsync from splitio.models.segments import Segment +from splitio.models import rule_based_segments +from splitio.optional.loaders import aiofiles, asyncio +import pytest class SegmentsSynchronizerTests(object): """Segments synchronizer test cases.""" @@ -17,6 +24,8 @@ def test_synchronize_segments_error(self, mocker): storage = mocker.Mock(spec=SegmentStorage) storage.get_change_number.return_value = -1 + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] api = mocker.Mock() @@ -24,8 +33,7 @@ def run(x): raise APIException("something broke") api.fetch_segment.side_effect = run - from splitio.sync.segment import SegmentSynchronizer - segments_synchronizer = SegmentSynchronizer(api, split_storage, storage) + segments_synchronizer = SegmentSynchronizer(api, split_storage, storage, rbs_storage) assert not segments_synchronizer.synchronize_segments() def test_synchronize_segments(self, mocker): @@ -33,6 +41,10 @@ def test_synchronize_segments(self, mocker): split_storage = mocker.Mock(spec=SplitStorage) split_storage.get_segment_names.return_value = ['segmentA', 'segmentB', 'segmentC'] + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = ['rbs'] + rbs_storage.get.return_value = rule_based_segments.from_raw({'name': 'rbs', 'conditions': [], 'trafficTypeName': 'user', 'changeNumber': 123, 'status': 'ACTIVE', 'excluded': {'keys': [], 'segments': [{'type': 'standard', 'name': 'segmentD'}]}}) + # Setup a mocked segment storage whose changenumber returns -1 on first fetch and # 123 afterwards. storage = mocker.Mock(spec=SegmentStorage) @@ -47,10 +59,14 @@ def change_number_mock(segment_name): if segment_name == 'segmentC' and change_number_mock._count_c == 0: change_number_mock._count_c = 1 return -1 + if segment_name == 'segmentD' and change_number_mock._count_d == 0: + change_number_mock._count_d = 1 + return -1 return 123 change_number_mock._count_a = 0 change_number_mock._count_b = 0 change_number_mock._count_c = 0 + change_number_mock._count_d = 0 storage.get_change_number.side_effect = change_number_mock # Setup a mocked segment api to return segments mentioned before. @@ -67,28 +83,35 @@ def fetch_segment_mock(segment_name, change_number, fetch_options): fetch_segment_mock._count_c = 1 return {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], 'since': -1, 'till': 123} + if segment_name == 'segmentD' and fetch_segment_mock._count_d == 0: + fetch_segment_mock._count_d = 1 + return {'name': 'segmentD', 'added': ['key10'], 'removed': [], + 'since': -1, 'till': 123} return {'added': [], 'removed': [], 'since': 123, 'till': 123} fetch_segment_mock._count_a = 0 fetch_segment_mock._count_b = 0 fetch_segment_mock._count_c = 0 + fetch_segment_mock._count_d = 0 api = mocker.Mock() api.fetch_segment.side_effect = fetch_segment_mock - from splitio.sync.segment import SegmentSynchronizer - segments_synchronizer = SegmentSynchronizer(api, split_storage, storage) + segments_synchronizer = SegmentSynchronizer(api, split_storage, storage, rbs_storage) assert segments_synchronizer.synchronize_segments() api_calls = [call for call in api.fetch_segment.mock_calls] - assert mocker.call('segmentA', -1, FetchOptions(True)) in api_calls - assert mocker.call('segmentB', -1, FetchOptions(True)) in api_calls - assert mocker.call('segmentC', -1, FetchOptions(True)) in api_calls - assert mocker.call('segmentA', 123, FetchOptions(True)) in api_calls - assert mocker.call('segmentB', 123, FetchOptions(True)) in api_calls - assert mocker.call('segmentC', 123, FetchOptions(True)) in api_calls + + assert mocker.call('segmentA', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentB', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentC', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentD', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentA', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentB', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentC', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentD', 123, FetchOptions(True, None, None, None, None)) in api_calls segment_put_calls = storage.put.mock_calls - segments_to_validate = set(['segmentA', 'segmentB', 'segmentC']) + segments_to_validate = set(['segmentA', 'segmentB', 'segmentC', 'segmentD']) for call in segment_put_calls: _, positional_args, _ = call segment = positional_args[0] @@ -100,6 +123,8 @@ def test_synchronize_segment(self, mocker): """Test particular segment update.""" split_storage = mocker.Mock(spec=SplitStorage) storage = mocker.Mock(spec=SegmentStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] def change_number_mock(segment_name): if change_number_mock._count_a == 0: @@ -120,21 +145,21 @@ def fetch_segment_mock(segment_name, change_number, fetch_options): api = mocker.Mock() api.fetch_segment.side_effect = fetch_segment_mock - from splitio.sync.segment import SegmentSynchronizer - segments_synchronizer = SegmentSynchronizer(api, split_storage, storage) + segments_synchronizer = SegmentSynchronizer(api, split_storage, storage, rbs_storage) segments_synchronizer.synchronize_segment('segmentA') api_calls = [call for call in api.fetch_segment.mock_calls] - assert mocker.call('segmentA', -1, FetchOptions(True)) in api_calls - assert mocker.call('segmentA', 123, FetchOptions(True)) in api_calls + assert mocker.call('segmentA', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert mocker.call('segmentA', 123, FetchOptions(True, None, None, None, None)) in api_calls def test_synchronize_segment_cdn(self, mocker): """Test particular segment update cdn bypass.""" mocker.patch('splitio.sync.segment._ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES', new=3) - from splitio.sync.segment import SegmentSynchronizer split_storage = mocker.Mock(spec=SplitStorage) storage = mocker.Mock(spec=SegmentStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] def change_number_mock(segment_name): change_number_mock._count_a += 1 @@ -168,21 +193,588 @@ def fetch_segment_mock(segment_name, change_number, fetch_options): api = mocker.Mock() api.fetch_segment.side_effect = fetch_segment_mock - segments_synchronizer = SegmentSynchronizer(api, split_storage, storage) + segments_synchronizer = SegmentSynchronizer(api, split_storage, storage, rbs_storage) segments_synchronizer.synchronize_segment('segmentA') - assert mocker.call('segmentA', -1, FetchOptions(True)) in api.fetch_segment.mock_calls - assert mocker.call('segmentA', 123, FetchOptions(True)) in api.fetch_segment.mock_calls + assert mocker.call('segmentA', -1, FetchOptions(True, None, None, None, None)) in api.fetch_segment.mock_calls + assert mocker.call('segmentA', 123, FetchOptions(True, None, None, None, None)) in api.fetch_segment.mock_calls segments_synchronizer._backoff = Backoff(1, 0.1) segments_synchronizer.synchronize_segment('segmentA', 12345) - assert mocker.call('segmentA', 12345, FetchOptions(True, 1234)) in api.fetch_segment.mock_calls + assert mocker.call('segmentA', 12345, FetchOptions(True, 1234, None, None, None)) in api.fetch_segment.mock_calls assert len(api.fetch_segment.mock_calls) == 8 # 2 ok + BACKOFF(2 since==till + 2 re-attempts) + CDN(2 since==till) def test_recreate(self, mocker): """Test recreate logic.""" - from splitio.sync.segment import SegmentSynchronizer - segments_synchronizer = SegmentSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()) + segments_synchronizer = SegmentSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + current_pool = segments_synchronizer._worker_pool + segments_synchronizer.recreate() + assert segments_synchronizer._worker_pool != current_pool + + +class SegmentsSynchronizerAsyncTests(object): + """Segments synchronizer async test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_segments_error(self, mocker): + """On error.""" + split_storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + + async def get_segment_names_rbs(): + return [] + rbs_storage.get_segment_names = get_segment_names_rbs + + async def get_segment_names(): + return ['segmentA', 'segmentB', 'segmentC'] + split_storage.get_segment_names = get_segment_names + + storage = mocker.Mock(spec=SegmentStorage) + async def get_change_number(*args): + return -1 + storage.get_change_number = get_change_number + + async def put(*args): + pass + storage.put = put + + api = mocker.Mock() + async def run(*args): + raise APIException("something broke") + api.fetch_segment = run + + segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + assert not await segments_synchronizer.synchronize_segments() + await segments_synchronizer.shutdown() + + @pytest.mark.asyncio + async def test_synchronize_segments(self, mocker): + """Test the normal operation flow.""" + split_storage = mocker.Mock(spec=SplitStorage) + async def get_segment_names(): + return ['segmentA', 'segmentB', 'segmentC'] + split_storage.get_segment_names = get_segment_names + + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + async def get_segment_names_rbs(): + return ['rbs'] + rbs_storage.get_segment_names = get_segment_names_rbs + + async def get_rbs(segment_name): + return rule_based_segments.from_raw({'name': 'rbs', 'conditions': [], 'trafficTypeName': 'user', 'changeNumber': 123, 'status': 'ACTIVE', 'excluded': {'keys': [], 'segments': [{'type': 'standard', 'name': 'segmentD'}]}}) + rbs_storage.get = get_rbs + + # Setup a mocked segment storage whose changenumber returns -1 on first fetch and + # 123 afterwards. + storage = mocker.Mock(spec=SegmentStorage) + + async def change_number_mock(segment_name): + if segment_name == 'segmentA' and change_number_mock._count_a == 0: + change_number_mock._count_a = 1 + return -1 + if segment_name == 'segmentB' and change_number_mock._count_b == 0: + change_number_mock._count_b = 1 + return -1 + if segment_name == 'segmentC' and change_number_mock._count_c == 0: + change_number_mock._count_c = 1 + return -1 + if segment_name == 'segmentD' and change_number_mock._count_d == 0: + change_number_mock._count_d = 1 + return -1 + return 123 + change_number_mock._count_a = 0 + change_number_mock._count_b = 0 + change_number_mock._count_c = 0 + change_number_mock._count_d = 0 + storage.get_change_number = change_number_mock + + self.segment_put = [] + async def put(segment): + self.segment_put.append(segment) + storage.put = put + + async def update(*args): + pass + storage.update = update + + # Setup a mocked segment api to return segments mentioned before. + self.options = [] + self.segment = [] + self.change = [] + async def fetch_segment_mock(segment_name, change_number, fetch_options): + self.segment.append(segment_name) + self.options.append(fetch_options) + self.change.append(change_number) + if segment_name == 'segmentA' and fetch_segment_mock._count_a == 0: + fetch_segment_mock._count_a = 1 + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentB' and fetch_segment_mock._count_b == 0: + fetch_segment_mock._count_b = 1 + return {'name': 'segmentB', 'added': ['key4', 'key5', 'key6'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentC' and fetch_segment_mock._count_c == 0: + fetch_segment_mock._count_c = 1 + return {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentD' and fetch_segment_mock._count_d == 0: + fetch_segment_mock._count_d = 1 + return {'name': 'segmentD', 'added': ['key10'], 'removed': [], + 'since': -1, 'till': 123} + return {'added': [], 'removed': [], 'since': 123, 'till': 123} + fetch_segment_mock._count_a = 0 + fetch_segment_mock._count_b = 0 + fetch_segment_mock._count_c = 0 + fetch_segment_mock._count_d = 0 + + api = mocker.Mock() + api.fetch_segment = fetch_segment_mock + + segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + assert await segments_synchronizer.synchronize_segments() + + api_calls = [] + for i in range(8): + api_calls.append((self.segment[i], self.change[i], self.options[i])) + + assert ('segmentD', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentD', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentA', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentA', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentB', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentB', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentC', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentC', 123, FetchOptions(True, None, None, None, None)) in api_calls + + segments_to_validate = set(['segmentA', 'segmentB', 'segmentC', 'segmentD']) + for segment in self.segment_put: + assert isinstance(segment, Segment) + assert segment.name in segments_to_validate + segments_to_validate.remove(segment.name) + + await segments_synchronizer.shutdown() + + @pytest.mark.asyncio + async def test_synchronize_segment(self, mocker): + """Test particular segment update.""" + split_storage = mocker.Mock(spec=SplitStorage) + storage = mocker.Mock(spec=SegmentStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + + async def get_segment_names_rbs(): + return [] + rbs_storage.get_segment_names = get_segment_names_rbs + + async def change_number_mock(segment_name): + if change_number_mock._count_a == 0: + change_number_mock._count_a = 1 + return -1 + return 123 + change_number_mock._count_a = 0 + storage.get_change_number = change_number_mock + async def put(segment): + pass + storage.put = put + + async def update(*args): + pass + storage.update = update + + self.options = [] + self.segment = [] + self.change = [] + async def fetch_segment_mock(segment_name, change_number, fetch_options): + self.segment.append(segment_name) + self.options.append(fetch_options) + self.change.append(change_number) + if fetch_segment_mock._count_a == 0: + fetch_segment_mock._count_a = 1 + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + return {'added': [], 'removed': [], 'since': 123, 'till': 123} + fetch_segment_mock._count_a = 0 + + api = mocker.Mock() + api.fetch_segment = fetch_segment_mock + + segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + await segments_synchronizer.synchronize_segment('segmentA') + + assert (self.segment[0], self.change[0], self.options[0]) == ('segmentA', -1, FetchOptions(True, None, None, None, None)) + assert (self.segment[1], self.change[1], self.options[1]) == ('segmentA', 123, FetchOptions(True, None, None, None, None)) + + await segments_synchronizer.shutdown() + + @pytest.mark.asyncio + async def test_synchronize_segment_cdn(self, mocker): + """Test particular segment update cdn bypass.""" + mocker.patch('splitio.sync.segment._ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES', new=3) + + split_storage = mocker.Mock(spec=SplitStorage) + storage = mocker.Mock(spec=SegmentStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + + async def get_segment_names_rbs(): + return [] + rbs_storage.get_segment_names = get_segment_names_rbs + + async def change_number_mock(segment_name): + change_number_mock._count_a += 1 + if change_number_mock._count_a == 1: + return -1 + elif change_number_mock._count_a >= 2 and change_number_mock._count_a <= 3: + return 123 + elif change_number_mock._count_a <= 7: + return 1234 + return 12345 # Return proper cn for CDN Bypass + change_number_mock._count_a = 0 + storage.get_change_number = change_number_mock + async def put(segment): + pass + storage.put = put + + async def update(*args): + pass + storage.update = update + + self.options = [] + self.segment = [] + self.change = [] + async def fetch_segment_mock(segment_name, change_number, fetch_options): + self.segment.append(segment_name) + self.options.append(fetch_options) + self.change.append(change_number) + fetch_segment_mock._count_a += 1 + if fetch_segment_mock._count_a == 1: + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + elif fetch_segment_mock._count_a == 2: + return {'added': [], 'removed': [], 'since': 123, 'till': 123} + elif fetch_segment_mock._count_a == 3: + return {'added': [], 'removed': [], 'since': 123, 'till': 1234} + elif fetch_segment_mock._count_a >= 4 and fetch_segment_mock._count_a <= 6: + return {'added': [], 'removed': [], 'since': 1234, 'till': 1234} + elif fetch_segment_mock._count_a == 7: + return {'added': [], 'removed': [], 'since': 1234, 'till': 12345} + return {'added': [], 'removed': [], 'since': 12345, 'till': 12345} + fetch_segment_mock._count_a = 0 + + api = mocker.Mock() + api.fetch_segment = fetch_segment_mock + + segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + await segments_synchronizer.synchronize_segment('segmentA') + + assert (self.segment[0], self.change[0], self.options[0]) == ('segmentA', -1, FetchOptions(True, None, None, None, None)) + assert (self.segment[1], self.change[1], self.options[1]) == ('segmentA', 123, FetchOptions(True, None, None, None, None)) + + segments_synchronizer._backoff = Backoff(1, 0.1) + await segments_synchronizer.synchronize_segment('segmentA', 12345) + assert (self.segment[7], self.change[7], self.options[7]) == ('segmentA', 12345, FetchOptions(True, 1234, None, None, None)) + assert len(self.segment) == 8 # 2 ok + BACKOFF(2 since==till + 2 re-attempts) + CDN(2 since==till) + await segments_synchronizer.shutdown() + + @pytest.mark.asyncio + async def test_recreate(self, mocker): + """Test recreate logic.""" + segments_synchronizer = SegmentSynchronizerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) current_pool = segments_synchronizer._worker_pool + await segments_synchronizer.shutdown() segments_synchronizer.recreate() + assert segments_synchronizer._worker_pool != current_pool + await segments_synchronizer.shutdown() + + +class LocalSegmentsSynchronizerTests(object): + """Segments synchronizer test cases.""" + + def test_synchronize_segments_error(self, mocker): + """On error.""" + split_storage = mocker.Mock(spec=SplitStorage) + split_storage.get_segment_names.return_value = ['segmentA', 'segmentB', 'segmentC'] + + storage = mocker.Mock(spec=SegmentStorage) + storage.get_change_number.return_value = -1 + + segments_synchronizer = LocalSegmentSynchronizer('/,/,/invalid folder name/,/,/', split_storage, storage) + assert not segments_synchronizer.synchronize_segments() + + def test_synchronize_segments(self, mocker): + """Test the normal operation flow.""" + split_storage = mocker.Mock(spec=InMemorySplitStorage) + split_storage.get_segment_names.return_value = ['segmentA', 'segmentB', 'segmentC'] + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) + + segment_a = {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + segment_b = {'name': 'segmentB', 'added': ['key4', 'key5', 'key6'], 'removed': [], + 'since': -1, 'till': 123} + segment_c = {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], + 'since': -1, 'till': 123} + blank = {'added': [], 'removed': [], 'since': 123, 'till': 123} + + def read_segment_from_json_file(*args, **kwargs): + if args[0] == 'segmentA': + return segment_a + if args[0] == 'segmentB': + return segment_b + if args[0] == 'segmentC': + return segment_c + return blank + + segments_synchronizer = LocalSegmentSynchronizer('segment_path', split_storage, storage) + segments_synchronizer._read_segment_from_json_file = read_segment_from_json_file + assert segments_synchronizer.synchronize_segments() + + segment = storage.get('segmentA') + assert segment.name == 'segmentA' + assert segment.contains('key1') + assert segment.contains('key2') + assert segment.contains('key3') + + segment = storage.get('segmentB') + assert segment.name == 'segmentB' + assert segment.contains('key4') + assert segment.contains('key5') + assert segment.contains('key6') + + segment = storage.get('segmentC') + assert segment.name == 'segmentC' + assert segment.contains('key7') + assert segment.contains('key8') + assert segment.contains('key9') + + # Should sync when changenumber is not changed + segment_a['added'] = ['key111'] + segments_synchronizer.synchronize_segments(['segmentA']) + segment = storage.get('segmentA') + assert segment.contains('key111') + + # Should not sync when changenumber below till + segment_a['till'] = 122 + segment_a['added'] = ['key222'] + segments_synchronizer.synchronize_segments(['segmentA']) + segment = storage.get('segmentA') + assert not segment.contains('key222') + + # Should sync when changenumber above till + segment_a['till'] = 124 + segments_synchronizer.synchronize_segments(['segmentA']) + segment = storage.get('segmentA') + assert segment.contains('key222') + + # Should sync when till is default (-1) + segment_a['till'] = -1 + segment_a['added'] = ['key33'] + segments_synchronizer.synchronize_segments(['segmentA']) + segment = storage.get('segmentA') + assert segment.contains('key33') + + # verify remove keys + segment_a['added'] = [] + segment_a['removed'] = ['key111'] + segment_a['till'] = 125 + segments_synchronizer.synchronize_segments(['segmentA']) + segment = storage.get('segmentA') + assert not segment.contains('key111') + + def test_reading_json(self, mocker): + """Test reading json file.""" + f = open("./segmentA.json", "w") + f.write('{"name": "segmentA", "added": ["key1", "key2", "key3"], "removed": [],"since": -1, "till": 123}') + f.close() + split_storage = mocker.Mock(spec=InMemorySplitStorage) + events_queue = queue.Queue() + storage = InMemorySegmentStorage(events_queue) + segments_synchronizer = LocalSegmentSynchronizer('.', split_storage, storage) + assert segments_synchronizer.synchronize_segments(['segmentA']) + + segment = storage.get('segmentA') + assert segment.name == 'segmentA' + assert segment.contains('key1') + assert segment.contains('key2') + assert segment.contains('key3') + + os.remove("./segmentA.json") + + def test_json_elements_sanitization(self, mocker): + """Test sanitization.""" + segment_synchronizer = LocalSegmentSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()) + segment1 = {"name": 'seg', "added": [], "removed": [], "since": -1, "till": 12} + + # should reject segment if 'name' is null + segment2 = {"name": None, "added": [], "removed": [], "since": -1, "till": 12} + exception_called = False + try: + segment_synchronizer._sanitize_segment(segment2) + except: + exception_called = True + assert(exception_called) + + # should reject segment if 'name' does not exist + segment2 = {"added": [], "removed": [], "since": -1, "till": 12} + exception_called = False + try: + segment_synchronizer._sanitize_segment(segment2) + except: + exception_called = True + assert(exception_called) + + # should add missing 'added' element + segment2 = {"name": 'seg', "removed": [], "since": -1, "till": 12} + assert(segment_synchronizer._sanitize_segment(segment2) == segment1) + + # should add missing 'removed' element + segment2 = {"name": 'seg', "added": [], "since": -1, "till": 12} + assert(segment_synchronizer._sanitize_segment(segment2) == segment1) + + # should reset added and remved to array if values are None + segment2 = {"name": 'seg', "added": None, "removed": None, "since": -1, "till": 12} + assert(segment_synchronizer._sanitize_segment(segment2) == segment1) + + # should reset since and till to -1 if values are None + segment3 = segment1.copy() + segment3["till"] = -1 + segment2 = {"name": 'seg', "added": [], "removed": [], "since": None, "till": None} + assert(segment_synchronizer._sanitize_segment(segment2) == segment3) + + # should add since and till with -1 if they are missing + segment2 = {"name": 'seg', "added": [], "removed": []} + assert(segment_synchronizer._sanitize_segment(segment2) == segment3) + + # should reset since and till to -1 if values are 0 + segment2 = {"name": 'seg', "added": [], "removed": [], "since": 0, "till": 0} + assert(segment_synchronizer._sanitize_segment(segment2) == segment3) + + # should reset till and since to -1 if values below -1 + segment2 = {"name": 'seg', "added": [], "removed": [], "since": -2, "till": -2} + assert(segment_synchronizer._sanitize_segment(segment2) == segment3) + + # should reset since to till if value above till + segment3["since"] = 12 + segment3["till"] = 12 + segment2 = {"name": 'seg', "added": [], "removed": [], "since": 20, "till": 12} + assert(segment_synchronizer._sanitize_segment(segment2) == segment3) + + +class LocalSegmentsSynchronizerTests(object): + """Segments synchronizer test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_segments_error(self, mocker): + """On error.""" + split_storage = mocker.Mock(spec=SplitStorage) + async def get_segment_names(): + return ['segmentA', 'segmentB', 'segmentC'] + split_storage.get_segment_names = get_segment_names + + storage = mocker.Mock(spec=SegmentStorage) + async def get_change_number(): + return -1 + storage.get_change_number = get_change_number + + segments_synchronizer = LocalSegmentSynchronizerAsync('/,/,/invalid folder name/,/,/', split_storage, storage) + assert not await segments_synchronizer.synchronize_segments() + + @pytest.mark.asyncio + async def test_synchronize_segments(self, mocker): + """Test the normal operation flow.""" + split_storage = mocker.Mock(spec=InMemorySplitStorage) + async def get_segment_names(): + return ['segmentA', 'segmentB', 'segmentC'] + split_storage.get_segment_names = get_segment_names + + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + + segment_a = {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + segment_b = {'name': 'segmentB', 'added': ['key4', 'key5', 'key6'], 'removed': [], + 'since': -1, 'till': 123} + segment_c = {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], + 'since': -1, 'till': 123} + blank = {'added': [], 'removed': [], 'since': 123, 'till': 123} + + async def read_segment_from_json_file(*args, **kwargs): + if args[0] == 'segmentA': + return segment_a + if args[0] == 'segmentB': + return segment_b + if args[0] == 'segmentC': + return segment_c + return blank + + segments_synchronizer = LocalSegmentSynchronizerAsync('segment_path', split_storage, storage) + segments_synchronizer._read_segment_from_json_file = read_segment_from_json_file + assert await segments_synchronizer.synchronize_segments() + + segment = await storage.get('segmentA') + assert segment.name == 'segmentA' + assert segment.contains('key1') + assert segment.contains('key2') + assert segment.contains('key3') + + segment = await storage.get('segmentB') + assert segment.name == 'segmentB' + assert segment.contains('key4') + assert segment.contains('key5') + assert segment.contains('key6') + + segment = await storage.get('segmentC') + assert segment.name == 'segmentC' + assert segment.contains('key7') + assert segment.contains('key8') + assert segment.contains('key9') + + # Should sync when changenumber is not changed + segment_a['added'] = ['key111'] + await segments_synchronizer.synchronize_segments(['segmentA']) + segment = await storage.get('segmentA') + assert segment.contains('key111') + + # Should not sync when changenumber below till + segment_a['till'] = 122 + segment_a['added'] = ['key222'] + await segments_synchronizer.synchronize_segments(['segmentA']) + segment = await storage.get('segmentA') + assert not segment.contains('key222') + + # Should sync when changenumber above till + segment_a['till'] = 124 + await segments_synchronizer.synchronize_segments(['segmentA']) + segment = await storage.get('segmentA') + assert segment.contains('key222') + + # Should sync when till is default (-1) + segment_a['till'] = -1 + segment_a['added'] = ['key33'] + await segments_synchronizer.synchronize_segments(['segmentA']) + segment = await storage.get('segmentA') + assert segment.contains('key33') + + # verify remove keys + segment_a['added'] = [] + segment_a['removed'] = ['key111'] + segment_a['till'] = 125 + await segments_synchronizer.synchronize_segments(['segmentA']) + segment = await storage.get('segmentA') + assert not segment.contains('key111') + + @pytest.mark.asyncio + async def test_reading_json(self, mocker): + """Test reading json file.""" + async with aiofiles.open("./segmentA.json", "w") as f: + await f.write('{"name": "segmentA", "added": ["key1", "key2", "key3"], "removed": [],"since": -1, "till": 123}') + split_storage = mocker.Mock(spec=InMemorySplitStorageAsync) + storage = InMemorySegmentStorageAsync(asyncio.Queue()) + segments_synchronizer = LocalSegmentSynchronizerAsync('.', split_storage, storage) + assert await segments_synchronizer.synchronize_segments(['segmentA']) + + segment = await storage.get('segmentA') + assert segment.name == 'segmentA' + assert segment.contains('key1') + assert segment.contains('key2') + assert segment.contains('key3') + + os.remove("./segmentA.json") \ No newline at end of file diff --git a/tests/sync/test_splits_synchronizer.py b/tests/sync/test_splits_synchronizer.py index 3b295d5b..b27606a4 100644 --- a/tests/sync/test_splits_synchronizer.py +++ b/tests/sync/test_splits_synchronizer.py @@ -1,118 +1,292 @@ """Split Worker tests.""" import pytest +import os +import json +import copy +import queue from splitio.util.backoff import Backoff from splitio.api import APIException from splitio.api.commons import FetchOptions -from splitio.storage import SplitStorage +from splitio.storage import SplitStorage, RuleBasedSegmentsStorage +from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySplitStorageAsync, InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync +from splitio.storage import FlagSetsFilter from splitio.models.splits import Split +from splitio.models.rule_based_segments import RuleBasedSegment +from splitio.sync.split import SplitSynchronizer, SplitSynchronizerAsync, LocalSplitSynchronizer, LocalSplitSynchronizerAsync, LocalhostMode +from splitio.optional.loaders import aiofiles, asyncio +from tests.integration import splits_json, rbsegments_json +splits_raw = [{ + 'changeNumber': 123, + 'trafficTypeName': 'user', + 'name': 'some_name', + 'trafficAllocation': 100, + 'trafficAllocationSeed': 123456, + 'seed': 321654, + 'status': 'ACTIVE', + 'killed': False, + 'defaultTreatment': 'off', + 'algo': 2, + 'conditions': [ + { + 'partitions': [ + {'treatment': 'on', 'size': 50}, + {'treatment': 'off', 'size': 50} + ], + 'contitionType': 'WHITELIST', + 'label': 'some_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'WHITELIST', + 'whitelistMatcherData': { + 'whitelist': ['k1', 'k2', 'k3'] + }, + 'negate': False, + } + ], + 'combiner': 'AND' + } + } + ], + 'sets': ['set1', 'set2'] +}] + +json_body = { + "ff": { + "t":1675095324253, + "s":-1, + 'd': [{ + 'changeNumber': 123, + 'trafficTypeName': 'user', + 'name': 'some_name', + 'trafficAllocation': 100, + 'trafficAllocationSeed': 123456, + 'seed': 321654, + 'status': 'ACTIVE', + 'killed': False, + 'defaultTreatment': 'off', + 'algo': 2, + 'conditions': [ + { + 'partitions': [ + {'treatment': 'on', 'size': 50}, + {'treatment': 'off', 'size': 50} + ], + 'contitionType': 'WHITELIST', + 'label': 'some_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'WHITELIST', + 'whitelistMatcherData': { + 'whitelist': ['k1', 'k2', 'k3'] + }, + 'negate': False, + } + ], + 'combiner': 'AND' + } + }, + { + "conditionType": "ROLLOUT", + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user" + }, + "matcherType": "IN_RULE_BASED_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "sample_rule_based_segment" + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ], + "label": "in rule based segment sample_rule_based_segment" + }, + ], + 'sets': ['set1', 'set2']}] + }, + "rbs": { + "t": 1675095324253, + "s": -1, + "d": [ + { + "changeNumber": 5, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[] + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "keySelector": { + "trafficType": "user", + "attribute": "email" + }, + "matcherType": "ENDS_WITH", + "negate": False, + "whitelistMatcherData": { + "whitelist": [ + "@split.io" + ] + } + } + ] + } + } + ] + } + ] + } +} class SplitsSynchronizerTests(object): """Split synchronizer test cases.""" + splits = copy.deepcopy(splits_raw) + def test_synchronize_splits_error(self, mocker): """Test that if fetching splits fails at some_point, the task will continue running.""" - storage = mocker.Mock(spec=SplitStorage) + storage = mocker.Mock(spec=InMemorySplitStorage) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) api = mocker.Mock() - def run(x, c): + def run(x, y, c): raise APIException("something broke") run._calls = 0 api.fetch_splits.side_effect = run storage.get_change_number.return_value = -1 + rbs_storage.get_change_number.return_value = -1 + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] - from splitio.sync.split import SplitSynchronizer - split_synchronizer = SplitSynchronizer(api, storage) + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) with pytest.raises(APIException): split_synchronizer.synchronize_splits(1) def test_synchronize_splits(self, mocker): """Test split sync.""" - storage = mocker.Mock(spec=SplitStorage) + storage = mocker.Mock(spec=InMemorySplitStorage) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) def change_number_mock(): change_number_mock._calls += 1 if change_number_mock._calls == 1: return -1 return 123 + + def rbs_change_number_mock(): + rbs_change_number_mock._calls += 1 + if rbs_change_number_mock._calls == 1: + return -1 + return 123 + change_number_mock._calls = 0 + rbs_change_number_mock._calls = 0 storage.get_change_number.side_effect = change_number_mock + rbs_storage.get_change_number.side_effect = rbs_change_number_mock + + class flag_set_filter(): + def should_filter(): + return False - api = mocker.Mock() - splits = [{ - 'changeNumber': 123, - 'trafficTypeName': 'user', - 'name': 'some_name', - 'trafficAllocation': 100, - 'trafficAllocationSeed': 123456, - 'seed': 321654, - 'status': 'ACTIVE', - 'killed': False, - 'defaultTreatment': 'off', - 'algo': 2, - 'conditions': [ - { - 'partitions': [ - {'treatment': 'on', 'size': 50}, - {'treatment': 'off', 'size': 50} - ], - 'contitionType': 'WHITELIST', - 'label': 'some_label', - 'matcherGroup': { - 'matchers': [ - { - 'matcherType': 'WHITELIST', - 'whitelistMatcherData': { - 'whitelist': ['k1', 'k2', 'k3'] - }, - 'negate': False, - } - ], - 'combiner': 'AND' - } - } - ] - }] + def intersect(sets): + return True + + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + api = mocker.Mock() def get_changes(*args, **kwargs): get_changes.called += 1 if get_changes.called == 1: - return { - 'splits': splits, - 'since': -1, - 'till': 123 - } + return json_body else: return { - 'splits': [], - 'since': 123, - 'till': 123 + "ff": { + "t":123, + "s":123, + 'd': [] + }, + "rbs": { + "t": 5, + "s": 5, + "d": [] + } } + get_changes.called = 0 api.fetch_splits.side_effect = get_changes - from splitio.sync.split import SplitSynchronizer - split_synchronizer = SplitSynchronizer(api, storage) + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) split_synchronizer.synchronize_splits() + + assert api.fetch_splits.mock_calls[0][1][0] == -1 + assert api.fetch_splits.mock_calls[0][1][2].cache_control_headers == True + assert api.fetch_splits.mock_calls[1][1][0] == 123 + assert api.fetch_splits.mock_calls[1][1][1] == 123 + assert api.fetch_splits.mock_calls[1][1][2].cache_control_headers == True - assert mocker.call(-1, FetchOptions(True)) in api.fetch_splits.mock_calls - assert mocker.call(123, FetchOptions(True)) in api.fetch_splits.mock_calls - - inserted_split = storage.put.mock_calls[0][1][0] + inserted_split = storage.update.mock_calls[0][1][0][0] assert isinstance(inserted_split, Split) assert inserted_split.name == 'some_name' + inserted_rbs = rbs_storage.update.mock_calls[0][1][0][0] + assert isinstance(inserted_rbs, RuleBasedSegment) + assert inserted_rbs.name == 'sample_rule_based_segment' + def test_not_called_on_till(self, mocker): """Test that sync is not called when till is less than previous changenumber""" - storage = mocker.Mock(spec=SplitStorage) + storage = mocker.Mock(spec=InMemorySplitStorage) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] def change_number_mock(): return 2 storage.get_change_number.side_effect = change_number_mock + rbs_storage.get_change_number.side_effect = change_number_mock def get_changes(*args, **kwargs): get_changes.called += 1 @@ -123,8 +297,7 @@ def get_changes(*args, **kwargs): api = mocker.Mock() api.fetch_splits.side_effect = get_changes - from splitio.sync.split import SplitSynchronizer - split_synchronizer = SplitSynchronizer(api, storage) + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) split_synchronizer.synchronize_splits(1) assert get_changes.called == 0 @@ -132,9 +305,9 @@ def get_changes(*args, **kwargs): def test_synchronize_splits_cdn(self, mocker): """Test split sync with bypassing cdn.""" mocker.patch('splitio.sync.split._ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES', new=3) - from splitio.sync.split import SplitSynchronizer - storage = mocker.Mock(spec=SplitStorage) + storage = mocker.Mock(spec=InMemorySplitStorage) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) def change_number_mock(): change_number_mock._calls += 1 @@ -145,73 +318,1098 @@ def change_number_mock(): elif change_number_mock._calls <= 7: return 1234 return 12345 # Return proper cn for CDN Bypass + + def rbs_change_number_mock(): + rbs_change_number_mock._calls += 1 + if rbs_change_number_mock._calls == 1: + return -1 + elif change_number_mock._calls >= 2 and change_number_mock._calls <= 3: + return 555 + elif change_number_mock._calls <= 9: + return 555 + return 666 # Return proper cn for CDN Bypass + change_number_mock._calls = 0 + rbs_change_number_mock._calls = 0 storage.get_change_number.side_effect = change_number_mock + rbs_storage.get_change_number.side_effect = rbs_change_number_mock api = mocker.Mock() - splits = [{ - 'changeNumber': 123, - 'trafficTypeName': 'user', - 'name': 'some_name', - 'trafficAllocation': 100, - 'trafficAllocationSeed': 123456, - 'seed': 321654, - 'status': 'ACTIVE', - 'killed': False, - 'defaultTreatment': 'off', - 'algo': 2, - 'conditions': [ - { - 'partitions': [ - {'treatment': 'on', 'size': 50}, - {'treatment': 'off', 'size': 50} - ], - 'contitionType': 'WHITELIST', - 'label': 'some_label', - 'matcherGroup': { - 'matchers': [ - { - 'matcherType': 'WHITELIST', - 'whitelistMatcherData': { - 'whitelist': ['k1', 'k2', 'k3'] - }, - 'negate': False, - } - ], - 'combiner': 'AND' - } - } - ] - }] - + rbs_1 = copy.deepcopy(json_body['rbs']['d']) def get_changes(*args, **kwargs): get_changes.called += 1 if get_changes.called == 1: - return { 'splits': splits, 'since': -1, 'till': 123 } + return { 'ff': { 'd': self.splits, 's': -1, 't': 123 }, + 'rbs': {"t": 555, "s": -1, "d": rbs_1}} elif get_changes.called == 2: - return { 'splits': [], 'since': 123, 'till': 123 } + return { 'ff': { 'd': [], 's': 123, 't': 123 }, + 'rbs': {"t": 555, "s": 555, "d": []}} elif get_changes.called == 3: - return { 'splits': [], 'since': 123, 'till': 1234 } + return { 'ff': { 'd': [], 's': 123, 't': 1234 }, + 'rbs': {"t": 555, "s": 555, "d": []}} elif get_changes.called >= 4 and get_changes.called <= 6: - return { 'splits': [], 'since': 1234, 'till': 1234 } + return { 'ff': { 'd': [], 's': 1234, 't': 1234 }, + 'rbs': {"t": 555, "s": 555, "d": []}} elif get_changes.called == 7: - return { 'splits': [], 'since': 1234, 'till': 12345 } - return { 'splits': [], 'since': 12345, 'till': 12345 } + return { 'ff': { 'd': [], 's': 1234, 't': 12345 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + elif get_changes.called == 8: + return { 'ff': { 'd': [], 's': 12345, 't': 12345 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + rbs_1[0]['excluded']['keys'] = ['bilal@split.io'] + return { 'ff': { 'd': [], 's': 12345, 't': 12345 }, + 'rbs': {"t": 666, "s": 666, "d": rbs_1}} + get_changes.called = 0 api.fetch_splits.side_effect = get_changes - split_synchronizer = SplitSynchronizer(api, storage) + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) split_synchronizer._backoff = Backoff(1, 1) split_synchronizer.synchronize_splits() - assert mocker.call(-1, FetchOptions(True)) in api.fetch_splits.mock_calls - assert mocker.call(123, FetchOptions(True)) in api.fetch_splits.mock_calls + assert api.fetch_splits.mock_calls[0][1][0] == -1 + assert api.fetch_splits.mock_calls[0][1][2].cache_control_headers == True + assert api.fetch_splits.mock_calls[1][1][0] == 123 + assert api.fetch_splits.mock_calls[1][1][2].cache_control_headers == True split_synchronizer._backoff = Backoff(1, 0.1) split_synchronizer.synchronize_splits(12345) - assert mocker.call(12345, FetchOptions(True, 1234)) in api.fetch_splits.mock_calls + assert api.fetch_splits.mock_calls[3][1][0] == 1234 + assert api.fetch_splits.mock_calls[3][1][2].cache_control_headers == True assert len(api.fetch_splits.mock_calls) == 8 # 2 ok + BACKOFF(2 since==till + 2 re-attempts) + CDN(2 since==till) - inserted_split = storage.put.mock_calls[0][1][0] + inserted_split = storage.update.mock_calls[0][1][0][0] + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + inserted_rbs = rbs_storage.update.mock_calls[0][1][0][0] + assert inserted_rbs.excluded.get_excluded_keys() == ["mauro@split.io","gaston@split.io"] + + split_synchronizer._backoff = Backoff(1, 0.1) + split_synchronizer.synchronize_splits(None, 666) + inserted_rbs = rbs_storage.update.mock_calls[8][1][0][0] + assert inserted_rbs.excluded.get_excluded_keys() == ['bilal@split.io'] + + def test_sync_flag_sets_with_config_sets(self, mocker): + """Test split sync with flag sets.""" + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue, ['set1', 'set2']) + events_queue = queue.Queue() + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + split = copy.deepcopy(self.splits[0]) + split['name'] = 'second' + splits1 = [self.splits[0].copy(), split] + splits2 = copy.deepcopy(self.splits) + splits3 = copy.deepcopy(self.splits) + splits4 = copy.deepcopy(self.splits) + api = mocker.Mock() + def get_changes(*args, **kwargs): + get_changes.called += 1 + if get_changes.called == 1: + return { 'ff': { 'd': splits1, 's': 123, 't': 123 }, + 'rbs': {'t': 123, 's': 123, 'd': []}} + elif get_changes.called == 2: + splits2[0]['sets'] = ['set3'] + return { 'ff': { 'd': splits2, 's': 124, 't': 124 }, + 'rbs': {'t': 124, 's': 124, 'd': []}} + elif get_changes.called == 3: + splits3[0]['sets'] = ['set1'] + return { 'ff': { 'd': splits3, 's': 12434, 't': 12434 }, + 'rbs': {'t': 12434, 's': 12434, 'd': []}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'new_split' + return { 'ff': { 'd': splits4, 's': 12438, 't': 12438 }, + 'rbs': {'t': 12438, 's': 12438, 'd': []}} + get_changes.called = 0 + api.fetch_splits.side_effect = get_changes + + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) + split_synchronizer._backoff = Backoff(1, 1) + split_synchronizer.synchronize_splits() + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(124) + assert storage.get('some_name') == None + + split_synchronizer.synchronize_splits(12434) + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(12438) + assert storage.get('new_name') == None + + def test_sync_flag_sets_without_config_sets(self, mocker): + """Test split sync with flag sets.""" + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split = copy.deepcopy(self.splits[0]) + split['name'] = 'second' + splits1 = [self.splits[0].copy(), split] + splits2 = copy.deepcopy(self.splits) + splits3 = copy.deepcopy(self.splits) + splits4 = copy.deepcopy(self.splits) + api = mocker.Mock() + def get_changes(*args, **kwargs): + get_changes.called += 1 + if get_changes.called == 1: + return { 'ff': { 'd': splits1, 's': 123, 't': 123 }, + 'rbs': {"t": 123, "s": 123, "d": []}} + elif get_changes.called == 2: + splits2[0]['sets'] = ['set3'] + return { 'ff': { 'd': splits2, 's': 124, 't': 124 }, + 'rbs': {"t": 124, "s": 124, "d": []}} + elif get_changes.called == 3: + splits3[0]['sets'] = ['set1'] + return { 'ff': { 'd': splits3, 's': 12434, 't': 12434 }, + 'rbs': {"t": 12434, "s": 12434, "d": []}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'third_split' + return { 'ff': { 'd': splits4, 's': 12438, 't': 12438 }, + 'rbs': {"t": 12438, "s": 12438, "d": []}} + get_changes.called = 0 + api.fetch_splits.side_effect = get_changes + + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) + split_synchronizer._backoff = Backoff(1, 1) + split_synchronizer.synchronize_splits() + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(124) + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(12434) + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(12438) + assert isinstance(storage.get('third_split'), Split) + +class SplitsSynchronizerAsyncTests(object): + """Split synchronizer test cases.""" + + splits = copy.deepcopy(splits_raw) + + @pytest.mark.asyncio + async def test_synchronize_splits_error(self, mocker): + """Test that if fetching splits fails at some_point, the task will continue running.""" + storage = mocker.Mock(spec=InMemorySplitStorageAsync) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + api = mocker.Mock() + + async def run(x, y, c): + raise APIException("something broke") + run._calls = 0 + api.fetch_splits = run + + async def get_change_number(*args): + return -1 + storage.get_change_number = get_change_number + rbs_storage.get_change_number = get_change_number + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + + with pytest.raises(APIException): + await split_synchronizer.synchronize_splits(1) + + @pytest.mark.asyncio + async def test_synchronize_splits(self, mocker): + """Test split sync.""" + storage = mocker.Mock(spec=InMemorySplitStorageAsync) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + + async def change_number_mock(): + change_number_mock._calls += 1 + if change_number_mock._calls == 1: + return -1 + return 123 + async def rbs_change_number_mock(): + rbs_change_number_mock._calls += 1 + if rbs_change_number_mock._calls == 1: + return -1 + return 123 + + change_number_mock._calls = 0 + rbs_change_number_mock._calls = 0 + storage.get_change_number = change_number_mock + rbs_storage.get_change_number.side_effect = rbs_change_number_mock + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + self.parsed_split = None + async def update(parsed_split, deleted, chanhe_number): + if len(parsed_split) > 0: + self.parsed_split = parsed_split + storage.update = update + + self.parsed_rbs = None + async def update(parsed_rbs, deleted, chanhe_number): + if len(parsed_rbs) > 0: + self.parsed_rbs = parsed_rbs + rbs_storage.update = update + + self.clear = False + async def clear(): + self.clear = True + storage.clear = clear + + self.clear2 = False + async def clear(): + self.clear2 = True + rbs_storage.clear = clear + + api = mocker.Mock() + self.change_number_1 = None + self.fetch_options_1 = None + self.change_number_2 = None + self.fetch_options_2 = None + async def get_changes(change_number, rbs_change_number, fetch_options): + get_changes.called += 1 + if get_changes.called == 1: + self.change_number_1 = change_number + self.fetch_options_1 = fetch_options + return json_body + else: + self.change_number_2 = change_number + self.fetch_options_2 = fetch_options + return { + "ff": { + "t":123, + "s":123, + 'd': [] + }, + "rbs": { + "t": 123, + "s": 123, + "d": [] + } + } + get_changes.called = 0 + api.fetch_splits = get_changes + api.clear_storage.return_value = False + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + await split_synchronizer.synchronize_splits() + + assert (-1, FetchOptions(True)._cache_control_headers) == (self.change_number_1, self.fetch_options_1._cache_control_headers) + assert (123, FetchOptions(True)._cache_control_headers) == (self.change_number_2, self.fetch_options_2._cache_control_headers) + inserted_split = self.parsed_split[0] + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + + inserted_rbs = self.parsed_rbs[0] + assert isinstance(inserted_rbs, RuleBasedSegment) + assert inserted_rbs.name == 'sample_rule_based_segment' + + + @pytest.mark.asyncio + async def test_not_called_on_till(self, mocker): + """Test that sync is not called when till is less than previous changenumber""" + storage = mocker.Mock(spec=InMemorySplitStorageAsync) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + async def change_number_mock(): + return 2 + storage.get_change_number = change_number_mock + rbs_storage.get_change_number.side_effect = change_number_mock + + async def get_changes(*args, **kwargs): + get_changes.called += 1 + return None + get_changes.called = 0 + api = mocker.Mock() + api.fetch_splits = get_changes + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + await split_synchronizer.synchronize_splits(1) + assert get_changes.called == 0 + + @pytest.mark.asyncio + async def test_synchronize_splits_cdn(self, mocker): + """Test split sync with bypassing cdn.""" + mocker.patch('splitio.sync.split._ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES', new=3) + storage = mocker.Mock(spec=InMemorySplitStorageAsync) + rbs_storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + async def change_number_mock(): + change_number_mock._calls += 1 + if change_number_mock._calls == 1: + return -1 + elif change_number_mock._calls >= 2 and change_number_mock._calls <= 3: + return 123 + elif change_number_mock._calls <= 7: + return 1234 + return 12345 # Return proper cn for CDN Bypass + async def rbs_change_number_mock(): + rbs_change_number_mock._calls += 1 + if rbs_change_number_mock._calls == 1: + return -1 + elif change_number_mock._calls >= 2 and change_number_mock._calls <= 3: + return 555 + elif change_number_mock._calls <= 9: + return 555 + return 666 # Return proper cn for CDN Bypass + + change_number_mock._calls = 0 + rbs_change_number_mock._calls = 0 + storage.get_change_number = change_number_mock + rbs_storage.get_change_number = rbs_change_number_mock + + self.parsed_split = None + async def update(parsed_split, deleted, change_number): + if len(parsed_split) > 0: + self.parsed_split = parsed_split + storage.update = update + + self.parsed_rbs = None + async def rbs_update(parsed, deleted, change_number): + if len(parsed) > 0: + self.parsed_rbs = parsed + rbs_storage.update = rbs_update + + api = mocker.Mock() + self.change_number_1 = None + self.fetch_options_1 = None + self.change_number_2 = None + self.fetch_options_2 = None + self.change_number_3 = None + self.fetch_options_3 = None + rbs_1 = copy.deepcopy(json_body['rbs']['d']) + + async def get_changes(change_number, rbs_change_number, fetch_options): + get_changes.called += 1 + if get_changes.called == 1: + self.change_number_1 = change_number + self.fetch_options_1 = fetch_options + return { 'ff': { 'd': self.splits, 's': -1, 't': 123 }, + 'rbs': {"t": 555, "s": -1, "d": rbs_1}} + elif get_changes.called == 2: + self.change_number_2 = change_number + self.fetch_options_2 = fetch_options + return { 'ff': { 'd': [], 's': 123, 't': 123 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + elif get_changes.called == 3: + return { 'ff': { 'd': [], 's': 123, 't': 1234 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + elif get_changes.called >= 4 and get_changes.called <= 6: + return { 'ff': { 'd': [], 's': 1234, 't': 1234 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + elif get_changes.called == 7: + return { 'ff': { 'd': [], 's': 1234, 't': 12345 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + elif get_changes.called == 8: + self.change_number_3 = change_number + self.fetch_options_3 = fetch_options + return { 'ff': { 'd': [], 's': 12345, 't': 12345 }, + 'rbs': {"t": 555, "s": 555, "d": []}} + rbs_1[0]['excluded']['keys'] = ['bilal@split.io'] + return { 'ff': { 'd': [], 's': 12345, 't': 12345 }, + 'rbs': {"t": 666, "s": 666, "d": rbs_1}} + + get_changes.called = 0 + api.fetch_splits = get_changes + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + self.clear = False + async def clear(): + self.clear = True + storage.clear = clear + + self.clear2 = False + async def clear(): + self.clear2 = True + rbs_storage.clear = clear + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + split_synchronizer._backoff = Backoff(1, 1) + await split_synchronizer.synchronize_splits() + + assert (-1, FetchOptions(True).cache_control_headers) == (self.change_number_1, self.fetch_options_1.cache_control_headers) + assert (123, FetchOptions(True).cache_control_headers) == (self.change_number_2, self.fetch_options_2.cache_control_headers) + + split_synchronizer._backoff = Backoff(1, 0.1) + await split_synchronizer.synchronize_splits(12345) + assert (12345, True, 1234) == (self.change_number_3, self.fetch_options_3.cache_control_headers, self.fetch_options_3.change_number) + assert get_changes.called == 8 # 2 ok + BACKOFF(2 since==till + 2 re-attempts) + CDN(2 since==till) + + inserted_split = self.parsed_split[0] assert isinstance(inserted_split, Split) assert inserted_split.name == 'some_name' + inserted_rbs = self.parsed_rbs[0] + assert inserted_rbs.excluded.get_excluded_keys() == ["mauro@split.io","gaston@split.io"] + + split_synchronizer._backoff = Backoff(1, 0.1) + await split_synchronizer.synchronize_splits(None, 666) + inserted_rbs = self.parsed_rbs[0] + assert inserted_rbs.excluded.get_excluded_keys() == ['bilal@split.io'] + + @pytest.mark.asyncio + async def test_sync_flag_sets_with_config_sets(self, mocker): + """Test split sync with flag sets.""" + internal_events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(internal_events_queue, ['set1', 'set2']) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + + split = self.splits[0].copy() + split['name'] = 'second' + splits1 = [self.splits[0].copy(), split] + splits2 = self.splits.copy() + splits3 = self.splits.copy() + splits4 = self.splits.copy() + api = mocker.Mock() + async def get_changes(*args, **kwargs): + get_changes.called += 1 + if get_changes.called == 1: + return { 'ff': { 'd': splits1, 's': 123, 't': 123 }, + 'rbs': {'t': 123, 's': 123, 'd': []}} + elif get_changes.called == 2: + splits2[0]['sets'] = ['set3'] + return { 'ff': { 'd': splits2, 's': 124, 't': 124 }, + 'rbs': {'t': 124, 's': 124, 'd': []}} + elif get_changes.called == 3: + splits3[0]['sets'] = ['set1'] + return { 'ff': { 'd': splits3, 's': 12434, 't': 12434 }, + 'rbs': {'t': 12434, 's': 12434, 'd': []}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'new_split' + return { 'ff': { 'd': splits4, 's': 12438, 't': 12438 }, + 'rbs': {'t': 12438, 's': 12438, 'd': []}} + + get_changes.called = 0 + api.fetch_splits = get_changes + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + split_synchronizer._backoff = Backoff(1, 1) + await split_synchronizer.synchronize_splits() + assert isinstance(await storage.get('some_name'), Split) + + await split_synchronizer.synchronize_splits(124) + assert await storage.get('some_name') == None + + await split_synchronizer.synchronize_splits(12434) + assert isinstance(await storage.get('some_name'), Split) + + await split_synchronizer.synchronize_splits(12438) + assert await storage.get('new_name') == None + + @pytest.mark.asyncio + async def test_sync_flag_sets_without_config_sets(self, mocker): + """Test split sync with flag sets.""" + internal_events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + split = self.splits[0].copy() + split['name'] = 'second' + splits1 = [self.splits[0].copy(), split] + splits2 = self.splits.copy() + splits3 = self.splits.copy() + splits4 = self.splits.copy() + api = mocker.Mock() + async def get_changes(*args, **kwargs): + get_changes.called += 1 + if get_changes.called == 1: + return { 'ff': { 'd': splits1, 's': 123, 't': 123 }, + 'rbs': {"t": 123, "s": 123, "d": []}} + elif get_changes.called == 2: + splits2[0]['sets'] = ['set3'] + return { 'ff': { 'd': splits2, 's': 124, 't': 124 }, + 'rbs': {"t": 124, "s": 124, "d": []}} + elif get_changes.called == 3: + splits3[0]['sets'] = ['set1'] + return { 'ff': { 'd': splits3, 's': 12434, 't': 12434 }, + 'rbs': {"t": 12434, "s": 12434, "d": []}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'third_split' + return { 'ff': { 'd': splits4, 's': 12438, 't': 12438 }, + 'rbs': {"t": 12438, "s": 12438, "d": []}} + get_changes.called = 0 + api.fetch_splits.side_effect = get_changes + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + split_synchronizer._backoff = Backoff(1, 1) + await split_synchronizer.synchronize_splits() + assert isinstance(await storage.get('new_split'), Split) + + await split_synchronizer.synchronize_splits(124) + assert isinstance(await storage.get('new_split'), Split) + + await split_synchronizer.synchronize_splits(12434) + assert isinstance(await storage.get('new_split'), Split) + + await split_synchronizer.synchronize_splits(12438) + assert isinstance(await storage.get('third_split'), Split) + +class LocalSplitsSynchronizerTests(object): + """Split synchronizer test cases.""" + + payload = copy.deepcopy(json_body) + + def test_synchronize_splits_error(self, mocker): + """Test that if fetching splits fails at some_point, the task will continue running.""" + storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + split_synchronizer = LocalSplitSynchronizer("/incorrect_file", storage, rbs_storage) + + with pytest.raises(Exception): + split_synchronizer.synchronize_splits(1) + + def test_synchronize_splits(self, mocker): + """Test split sync.""" + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + def read_splits_from_json_file(*args, **kwargs): + return self.payload + + split_synchronizer = LocalSplitSynchronizer("split.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer._read_feature_flags_from_json_file = read_splits_from_json_file + + split_synchronizer.synchronize_splits() + inserted_split = storage.get(self.payload["ff"]["d"][0]['name']) + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + + # Should sync when changenumber is not changed + self.payload["ff"]["d"][0]['killed'] = True + split_synchronizer.synchronize_splits() + inserted_split = storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed + + # Should not sync when changenumber is less than stored + self.payload["ff"]["t"] = 122 + self.payload["ff"]["d"][0]['killed'] = False + split_synchronizer.synchronize_splits() + inserted_split = storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed + + # Should sync when changenumber is higher than stored + self.payload["ff"]["t"] = 1675095324999 + split_synchronizer._current_json_sha = "-1" + split_synchronizer.synchronize_splits() + inserted_split = storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed == False + + # Should sync when till is default (-1) + self.payload["ff"]["t"] = -1 + split_synchronizer._current_json_sha = "-1" + self.payload["ff"]["d"][0]['killed'] = True + split_synchronizer.synchronize_splits() + inserted_split = storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed == True + + def test_sync_flag_sets_with_config_sets(self, mocker): + """Test split sync with flag sets.""" + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue, ['set1', 'set2']) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + split = self.payload["ff"]["d"][0].copy() + split['name'] = 'second' + splits1 = [self.payload["ff"]["d"][0].copy(), split] + splits2 = self.payload["ff"]["d"].copy() + splits3 = self.payload["ff"]["d"].copy() + splits4 = self.payload["ff"]["d"].copy() + + self.called = 0 + def read_feature_flags_from_json_file(*args, **kwargs): + self.called += 1 + if self.called == 1: + return {"ff": {"d": splits1, "t": 123, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 2: + splits2[0]['sets'] = ['set3'] + return {"ff": {"d": splits2, "t": 124, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 3: + splits3[0]['sets'] = ['set1'] + return {"ff": {"d": splits3, "t": 12434, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'new_split' + return {"ff": {"d": splits4, "t": 12438, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + + split_synchronizer = LocalSplitSynchronizer("split.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer._read_feature_flags_from_json_file = read_feature_flags_from_json_file + + split_synchronizer.synchronize_splits() + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(124) + assert storage.get('some_name') == None + + split_synchronizer.synchronize_splits(12434) + assert isinstance(storage.get('some_name'), Split) + + split_synchronizer.synchronize_splits(12438) + assert storage.get('new_name') == None + + def test_sync_flag_sets_without_config_sets(self, mocker): + """Test split sync with flag sets.""" + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + + split = self.payload["ff"]["d"][0].copy() + split['name'] = 'second' + splits1 = [self.payload["ff"]["d"][0].copy(), split] + splits2 = self.payload["ff"]["d"].copy() + splits3 = self.payload["ff"]["d"].copy() + splits4 = self.payload["ff"]["d"].copy() + + self.called = 0 + def read_feature_flags_from_json_file(*args, **kwargs): + self.called += 1 + if self.called == 1: + return {"ff": {"d": splits1, "t": 123, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 2: + splits2[0]['sets'] = ['set3'] + return {"ff": {"d": splits2, "t": 124, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 3: + splits3[0]['sets'] = ['set1'] + return {"ff": {"d": splits3, "t": 12434, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'third_split' + return {"ff": {"d": splits4, "t": 12438, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + + split_synchronizer = LocalSplitSynchronizer("split.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer._read_feature_flags_from_json_file = read_feature_flags_from_json_file + + split_synchronizer.synchronize_splits() + assert isinstance(storage.get('new_split'), Split) + + split_synchronizer.synchronize_splits(124) + assert isinstance(storage.get('new_split'), Split) + + split_synchronizer.synchronize_splits(12434) + assert isinstance(storage.get('new_split'), Split) + + split_synchronizer.synchronize_splits(12438) + assert isinstance(storage.get('third_split'), Split) + + def test_reading_json(self, mocker): + """Test reading json file.""" + f = open("./splits.json", "w") + f.write(json.dumps(self.payload)) + f.close() + events_queue = queue.Queue() + storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split_synchronizer = LocalSplitSynchronizer("./splits.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer.synchronize_splits() + + inserted_split = storage.get(self.payload['ff']['d'][0]['name']) + assert isinstance(inserted_split, Split) + assert inserted_split.name == self.payload['ff']['d'][0]['name'] + + inserted_rbs = rbs_storage.get(self.payload['rbs']['d'][0]['name']) + assert isinstance(inserted_rbs, RuleBasedSegment) + assert inserted_rbs.name == self.payload['rbs']['d'][0]['name'] + + os.remove("./splits.json") + + def test_json_elements_sanitization(self, mocker): + """Test sanitization.""" + split_synchronizer = LocalSplitSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + + # check no changes if all elements exist with valid values + parsed = {"ff": {"d": [], "s": -1, "t": -1}, "rbs": {"d": [], "s": -1, "t": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed) == parsed) + + # check set since to -1 when is None + parsed2 = parsed.copy() + parsed2['ff']['s'] = None + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check no changes if since > -1 + parsed2 = parsed.copy() + parsed2['ff']['s'] = 12 + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check set till to -1 when is None + parsed2 = parsed.copy() + parsed2['ff']['t'] = None + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check add since when missing + parsed2 = {"ff": {"d": [], "t": -1}, "rbs": {"d": [], "s": -1, "t": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check add till when missing + parsed2 = {"ff": {"d": [], "s": -1}, "rbs": {"d": [], "s": -1, "t": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check add splits when missing + parsed2 = {"ff": {"s": -1, "t": -1}, "rbs": {"d": [], "s": -1, "t": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check add since when missing + parsed2 = {"ff": {"d": [], "t": -1}, "rbs": {"d": [], "t": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check add till when missing + parsed2 = {"ff": {"d": [], "s": -1}, "rbs": {"d": [], "s": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + # check add splits when missing + parsed2 = {"ff": {"s": -1, "t": -1}, "rbs": {"s": -1, "t": -1}} + assert (split_synchronizer._sanitize_json_elements(parsed2) == parsed) + + def test_elements_sanitization(self, mocker): + """Test sanitization.""" + split_synchronizer = LocalSplitSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + + # No changes when split structure is good + assert (split_synchronizer._sanitize_feature_flag_elements(splits_json["splitChange1_1"]['ff']['d']) == splits_json["splitChange1_1"]['ff']['d']) + + # test 'trafficTypeName' value None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['trafficTypeName'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split) == splits_json["splitChange1_1"]['ff']['d']) + + # test 'trafficAllocation' value None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['trafficAllocation'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split) == splits_json["splitChange1_1"]['ff']['d']) + + # test 'trafficAllocation' valid value should not change + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['trafficAllocation'] = 50 + assert (split_synchronizer._sanitize_feature_flag_elements(split) == split) + + # test 'trafficAllocation' invalid value should change + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['trafficAllocation'] = 110 + assert (split_synchronizer._sanitize_feature_flag_elements(split) == splits_json["splitChange1_1"]['ff']['d']) + + # test 'trafficAllocationSeed' is set to millisec epoch when None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['trafficAllocationSeed'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['trafficAllocationSeed'] > 0) + + # test 'trafficAllocationSeed' is set to millisec epoch when 0 + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['trafficAllocationSeed'] = 0 + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['trafficAllocationSeed'] > 0) + + # test 'seed' is set to millisec epoch when None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['seed'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['seed'] > 0) + + # test 'seed' is set to millisec epoch when its 0 + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['seed'] = 0 + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['seed'] > 0) + + # test 'status' is set to ACTIVE when None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['status'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split) == splits_json["splitChange1_1"]['ff']['d']) + + # test 'status' is set to ACTIVE when incorrect + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['status'] = 'ww' + assert (split_synchronizer._sanitize_feature_flag_elements(split) == splits_json["splitChange1_1"]['ff']['d']) + + # test ''killed' is set to False when incorrect + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['killed'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split) == splits_json["splitChange1_1"]['ff']['d']) + + # test 'defaultTreatment' is set to on when None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['defaultTreatment'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['defaultTreatment'] == 'control') + + # test 'defaultTreatment' is set to on when its empty + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['defaultTreatment'] = ' ' + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['defaultTreatment'] == 'control') + + # test 'changeNumber' is set to 0 when None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['changeNumber'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['changeNumber'] == 0) + + # test 'changeNumber' is set to 0 when invalid + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['changeNumber'] = -33 + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['changeNumber'] == 0) + + # test 'algo' is set to 2 when None + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['algo'] = None + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['algo'] == 2) + + # test 'algo' is set to 2 when higher than 2 + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['algo'] = 3 + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['algo'] == 2) + + # test 'algo' is set to 2 when lower than 2 + split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]['algo'] = 1 + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['algo'] == 2) + + split = splits_json["splitChange1_1"]['ff']['d'].copy() + del split[0]['prerequisites'] + assert (split_synchronizer._sanitize_feature_flag_elements(split)[0]['prerequisites'] == []) + + # test 'status' is set to ACTIVE when None + rbs = copy.deepcopy(json_body["rbs"]["d"]) + rbs[0]['status'] = None + assert (split_synchronizer._sanitize_rb_segment_elements(rbs)[0]['status'] == 'ACTIVE') + + # test 'changeNumber' is set to 0 when invalid + rbs = copy.deepcopy(json_body["rbs"]["d"]) + rbs[0]['changeNumber'] = -2 + assert (split_synchronizer._sanitize_rb_segment_elements(rbs)[0]['changeNumber'] == 0) + + rbs = copy.deepcopy(json_body["rbs"]["d"]) + del rbs[0]['conditions'] + assert (len(split_synchronizer._sanitize_rb_segment_elements(rbs)[0]['conditions']) == 1) + + def test_condition_sanitization(self, mocker): + """Test sanitization.""" + split_synchronizer = LocalSplitSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()) + + # test missing all conditions with default rule set to 100% off + split = splits_json["splitChange1_1"]['ff']['d'].copy() + target_split = splits_json["splitChange1_1"]['ff']['d'].copy() + target_split[0]["conditions"][0]['partitions'][0]['size'] = 0 + target_split[0]["conditions"][0]['partitions'][1]['size'] = 100 + del split[0]["conditions"] + assert (split_synchronizer._sanitize_feature_flag_elements(split) == target_split) + + # test missing ALL_KEYS condition matcher with default rule set to 100% off + split = splits_json["splitChange1_1"]['ff']['d'].copy() + target_split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]["conditions"][0]["matcherGroup"]["matchers"][0]["matcherType"] = "IN_STR" + target_split = split.copy() + target_split[0]["conditions"].append(splits_json["splitChange1_1"]['ff']['d'][0]["conditions"][0]) + target_split[0]["conditions"][1]['partitions'][0]['size'] = 0 + target_split[0]["conditions"][1]['partitions'][1]['size'] = 100 + assert (split_synchronizer._sanitize_feature_flag_elements(split) == target_split) + + # test missing ROLLOUT condition type with default rule set to 100% off + split = splits_json["splitChange1_1"]['ff']['d'].copy() + target_split = splits_json["splitChange1_1"]['ff']['d'].copy() + split[0]["conditions"][0]["conditionType"] = "NOT" + target_split = split.copy() + target_split[0]["conditions"].append(splits_json["splitChange1_1"]['ff']['d'][0]["conditions"][0]) + target_split[0]["conditions"][1]['partitions'][0]['size'] = 0 + target_split[0]["conditions"][1]['partitions'][1]['size'] = 100 + assert (split_synchronizer._sanitize_feature_flag_elements(split) == target_split) + +class LocalSplitsSynchronizerAsyncTests(object): + """Split synchronizer test cases.""" + + payload = copy.deepcopy(json_body) + + @pytest.mark.asyncio + async def test_synchronize_splits_error(self, mocker): + """Test that if fetching splits fails at some_point, the task will continue running.""" + storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + split_synchronizer = LocalSplitSynchronizerAsync("/incorrect_file", storage, rbs_storage) + + with pytest.raises(Exception): + await split_synchronizer.synchronize_splits(1) + + @pytest.mark.asyncio + async def test_synchronize_splits(self, mocker): + """Test split sync.""" + internal_events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + + async def read_splits_from_json_file(*args, **kwargs): + return self.payload + + split_synchronizer = LocalSplitSynchronizerAsync("split.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer._read_feature_flags_from_json_file = read_splits_from_json_file + + await split_synchronizer.synchronize_splits() + inserted_split = await storage.get(self.payload["ff"]["d"][0]['name']) + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + + # Should sync when changenumber is not changed + self.payload["ff"]["d"][0]['killed'] = True + await split_synchronizer.synchronize_splits() + inserted_split = await storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed + + # Should not sync when changenumber is less than stored + self.payload["ff"]["t"] = 122 + self.payload["ff"]["d"][0]['killed'] = False + await split_synchronizer.synchronize_splits() + inserted_split = await storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed + + # Should sync when changenumber is higher than stored + self.payload["ff"]["t"] = 1675095324999 + split_synchronizer._current_json_sha = "-1" + await split_synchronizer.synchronize_splits() + inserted_split = await storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed == False + + # Should sync when till is default (-1) + self.payload["ff"]["t"] = -1 + split_synchronizer._current_json_sha = "-1" + self.payload["ff"]["d"][0]['killed'] = True + await split_synchronizer.synchronize_splits() + inserted_split = await storage.get(self.payload["ff"]["d"][0]['name']) + assert inserted_split.killed == True + + @pytest.mark.asyncio + async def test_sync_flag_sets_with_config_sets(self, mocker): + """Test split sync with flag sets.""" + internal_events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(internal_events_queue, ['set1', 'set2']) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + + split = self.payload["ff"]["d"][0].copy() + split['name'] = 'second' + splits1 = [self.payload["ff"]["d"][0].copy(), split] + splits2 = self.payload["ff"]["d"].copy() + splits3 = self.payload["ff"]["d"].copy() + splits4 = self.payload["ff"]["d"].copy() + + self.called = 0 + async def read_feature_flags_from_json_file(*args, **kwargs): + self.called += 1 + if self.called == 1: + return {"ff": {"d": splits1, "t": 123, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 2: + splits2[0]['sets'] = ['set3'] + return {"ff": {"d": splits2, "t": 124, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 3: + splits3[0]['sets'] = ['set1'] + return {"ff": {"d": splits3, "t": 12434, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'new_split' + return {"ff": {"d": splits4, "t": 12438, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + + split_synchronizer = LocalSplitSynchronizerAsync("split.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer._read_feature_flags_from_json_file = read_feature_flags_from_json_file + + await split_synchronizer.synchronize_splits() + assert isinstance(await storage.get('some_name'), Split) + + await split_synchronizer.synchronize_splits(124) + assert await storage.get('some_name') == None + + await split_synchronizer.synchronize_splits(12434) + assert isinstance(await storage.get('some_name'), Split) + + await split_synchronizer.synchronize_splits(12438) + assert await storage.get('new_name') == None + + @pytest.mark.asyncio + async def test_sync_flag_sets_without_config_sets(self, mocker): + """Test split sync with flag sets.""" + internal_events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + + split = self.payload["ff"]["d"][0].copy() + split['name'] = 'second' + splits1 = [self.payload["ff"]["d"][0].copy(), split] + splits2 = self.payload["ff"]["d"].copy() + splits3 = self.payload["ff"]["d"].copy() + splits4 = self.payload["ff"]["d"].copy() + + self.called = 0 + async def read_feature_flags_from_json_file(*args, **kwargs): + self.called += 1 + if self.called == 1: + return {"ff": {"d": splits1, "t": 123, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 2: + return {"ff": {"d": splits2, "t": 124, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + elif self.called == 3: + splits3[0]['sets'] = ['set1'] + return {"ff": {"d": splits3, "t": 12434, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + splits4[0]['sets'] = ['set6'] + splits4[0]['name'] = 'third_split' + return {"ff": {"d": splits4, "t": 12438, "s": -1}, "rbs": {"d": [], "t": -1, "s": -1}} + + split_synchronizer = LocalSplitSynchronizerAsync("split.json", storage, rbs_storage, LocalhostMode.JSON) + split_synchronizer._read_feature_flags_from_json_file = read_feature_flags_from_json_file + + await split_synchronizer.synchronize_splits() + assert isinstance(await storage.get('new_split'), Split) + + await split_synchronizer.synchronize_splits(124) + assert isinstance(await storage.get('new_split'), Split) + + await split_synchronizer.synchronize_splits(12434) + assert isinstance(await storage.get('new_split'), Split) + + await split_synchronizer.synchronize_splits(12438) + assert isinstance(await storage.get('third_split'), Split) + + @pytest.mark.asyncio + async def test_reading_json(self, mocker): + """Test reading json file.""" + async with aiofiles.open("./splits.json", "w") as f: + await f.write(json.dumps(self.payload)) + internal_events_queue = asyncio.Queue() + storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + split_synchronizer = LocalSplitSynchronizerAsync("./splits.json", storage, rbs_storage, LocalhostMode.JSON) + await split_synchronizer.synchronize_splits() + + inserted_split = await storage.get(self.payload['ff']['d'][0]['name']) + assert isinstance(inserted_split, Split) + assert inserted_split.name == self.payload['ff']['d'][0]['name'] + + inserted_rbs = await rbs_storage.get(self.payload['rbs']['d'][0]['name']) + assert isinstance(inserted_rbs, RuleBasedSegment) + assert inserted_rbs.name == self.payload['rbs']['d'][0]['name'] + + os.remove("./splits.json") diff --git a/tests/sync/test_synchronizer.py b/tests/sync/test_synchronizer.py index 43377841..1244429b 100644 --- a/tests/sync/test_synchronizer.py +++ b/tests/sync/test_synchronizer.py @@ -1,97 +1,224 @@ """Synchronizer tests.""" - +import unittest.mock as mock import pytest +import queue +import asyncio -from splitio.sync.synchronizer import Synchronizer, SplitTasks, SplitSynchronizers -from splitio.tasks.split_sync import SplitSynchronizationTask -from splitio.tasks.segment_sync import SegmentSynchronizationTask -from splitio.tasks.impressions_sync import ImpressionsSyncTask, ImpressionsCountSyncTask -from splitio.tasks.events_sync import EventsSyncTask -from splitio.sync.split import SplitSynchronizer -from splitio.sync.segment import SegmentSynchronizer -from splitio.sync.impression import ImpressionSynchronizer, ImpressionsCountSynchronizer -from splitio.sync.event import EventSynchronizer -from splitio.storage import SegmentStorage, SplitStorage -from splitio.api import APIException +from splitio.sync.synchronizer import Synchronizer, SynchronizerAsync, SplitTasks, SplitSynchronizers, LocalhostSynchronizer, LocalhostSynchronizerAsync, RedisSynchronizer, RedisSynchronizerAsync +from splitio.tasks.split_sync import SplitSynchronizationTask, SplitSynchronizationTaskAsync +from splitio.tasks.unique_keys_sync import UniqueKeysSyncTask, ClearFilterSyncTask, UniqueKeysSyncTaskAsync, ClearFilterSyncTaskAsync +from splitio.tasks.segment_sync import SegmentSynchronizationTask, SegmentSynchronizationTaskAsync +from splitio.tasks.impressions_sync import ImpressionsSyncTask, ImpressionsCountSyncTask, ImpressionsCountSyncTaskAsync, ImpressionsSyncTaskAsync +from splitio.tasks.events_sync import EventsSyncTask, EventsSyncTaskAsync +from splitio.sync.split import SplitSynchronizer, SplitSynchronizerAsync, LocalSplitSynchronizer, LocalhostMode, LocalSplitSynchronizerAsync +from splitio.sync.segment import SegmentSynchronizer, SegmentSynchronizerAsync, LocalSegmentSynchronizer, LocalSegmentSynchronizerAsync +from splitio.sync.impression import ImpressionSynchronizer, ImpressionSynchronizerAsync, ImpressionsCountSynchronizer, ImpressionsCountSynchronizerAsync +from splitio.sync.event import EventSynchronizer, EventSynchronizerAsync +from splitio.storage import SegmentStorage, SplitStorage, RuleBasedSegmentsStorage +from splitio.api import APIException, APIUriException from splitio.models.splits import Split from splitio.models.segments import Segment +from splitio.storage.inmemmory import InMemorySegmentStorage, InMemorySplitStorage, InMemorySegmentStorageAsync, InMemorySplitStorageAsync, \ + InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync +splits = [{ + 'changeNumber': 123, + 'trafficTypeName': 'user', + 'name': 'some_name', + 'trafficAllocation': 100, + 'trafficAllocationSeed': 123456, + 'seed': 321654, + 'status': 'ACTIVE', + 'killed': False, + 'defaultTreatment': 'off', + 'algo': 2, + 'conditions': [{ + 'conditionType': 'WHITELIST', + 'matcherGroup':{ + 'combiner': 'AND', + 'matchers':[{ + 'matcherType': 'IN_SEGMENT', + 'negate': False, + 'userDefinedSegmentMatcherData': { + 'segmentName': 'segmentA' + } + }] + }, + 'partitions': [{ + 'size': 100, + 'treatment': 'on' + }] + }] +}] class SynchronizerTests(object): def test_sync_all_failed_splits(self, mocker): api = mocker.Mock() storage = mocker.Mock() + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] - def run(x, c): + def run(x, y, c): raise APIException("something broke") api.fetch_splits.side_effect = run - split_sync = SplitSynchronizer(api, storage) + split_sync = SplitSynchronizer(api, storage, mocker.Mock()) split_synchronizers = SplitSynchronizers(split_sync, mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) sychronizer = Synchronizer(split_synchronizers, mocker.Mock(spec=SplitTasks)) sychronizer.synchronize_splits(None) # APIExceptions are handled locally and should not be propagated! - sychronizer.sync_all() # sync_all should not throw! + # test forcing to have only one retry attempt and then exit + sychronizer.sync_all(1) # sync_all should not throw! + + def test_sync_all_failed_splits_with_flagsets(self, mocker): + api = mocker.Mock() + storage = mocker.Mock() + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + def run(x, y, c): + raise APIException("something broke", 414) + api.fetch_splits.side_effect = run + + split_sync = SplitSynchronizer(api, storage, mocker.Mock()) + split_synchronizers = SplitSynchronizers(split_sync, mocker.Mock(), mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = Synchronizer(split_synchronizers, mocker.Mock(spec=SplitTasks)) + + synchronizer.synchronize_splits(None) + synchronizer.sync_all(3) + assert synchronizer._backoff._attempt == 0 def test_sync_all_failed_segments(self, mocker): api = mocker.Mock() storage = mocker.Mock() split_storage = mocker.Mock(spec=SplitStorage) split_storage.get_segment_names.return_value = ['segmentA'] + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] split_sync = mocker.Mock(spec=SplitSynchronizer) split_sync.synchronize_splits.return_value = None - def run(x, y): + def run(x, y, c): raise APIException("something broke") api.fetch_segment.side_effect = run - segment_sync = SegmentSynchronizer(api, split_storage, storage) + segment_sync = SegmentSynchronizer(api, split_storage, storage, rbs_storage) split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), mocker.Mock(), mocker.Mock()) sychronizer = Synchronizer(split_synchronizers, mocker.Mock(spec=SplitTasks)) - sychronizer.sync_all() # SyncAll should not throw! + sychronizer.sync_all(1) # SyncAll should not throw! assert not sychronizer._synchronize_segments() - splits = [{ - 'changeNumber': 123, - 'trafficTypeName': 'user', - 'name': 'some_name', - 'trafficAllocation': 100, - 'trafficAllocationSeed': 123456, - 'seed': 321654, - 'status': 'ACTIVE', - 'killed': False, - 'defaultTreatment': 'off', - 'algo': 2, - 'conditions': [] - }] + def test_synchronize_splits(self, mocker): + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split_api = mocker.Mock() + split_api.fetch_splits.return_value = {'ff': {'d': splits, 's': 123, + 't': 123}, 'rbs': {'d': [], 's': -1, 't': -1}} + split_sync = SplitSynchronizer(split_api, split_storage, rbs_storage) + segment_storage = InMemorySegmentStorage(events_queue) + segment_api = mocker.Mock() + segment_api.fetch_segment.return_value = {'name': 'segmentA', 'added': ['key1', 'key2', + 'key3'], 'removed': [], 'since': 123, 'till': 123} + segment_sync = SegmentSynchronizer(segment_api, split_storage, segment_storage, rbs_storage) + split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = Synchronizer(split_synchronizers, mocker.Mock(spec=SplitTasks)) + + synchronizer.synchronize_splits(123) + + inserted_split = split_storage.get('some_name') + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + + if not segment_sync._worker_pool.wait_for_completion(): + inserted_segment = segment_storage.get('segmentA') + assert inserted_segment.name == 'segmentA' + assert inserted_segment.keys == {'key1', 'key2', 'key3'} + + def test_synchronize_splits_calling_segment_sync_once(self, mocker): + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorage(events_queue) + split_api = mocker.Mock() + split_api.fetch_splits.return_value = {'ff': {'d': splits, 's': 123, + 't': 123}, 'rbs': {'d': [], 's': -1, 't': -1}} + + split_sync = SplitSynchronizer(split_api, split_storage, rbs_storage) + counts = {'segments': 0} + + def sync_segments(*_): + """Sync Segments.""" + counts['segments'] += 1 + return True + + segment_sync = mocker.Mock() + segment_sync.synchronize_segments.side_effect = sync_segments + segment_sync.segment_exist_in_storage.return_value = False + split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = Synchronizer(split_synchronizers, mocker.Mock(spec=SplitTasks)) + synchronizer.synchronize_splits(123, True) + + assert counts['segments'] == 1 def test_sync_all(self, mocker): split_storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] split_storage.get_change_number.return_value = 123 split_storage.get_segment_names.return_value = ['segmentA'] + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + split_storage.flag_set_filter = flag_set_filter + split_storage.flag_set_filter.flag_sets = {} + split_storage.flag_set_filter.sorted_flag_sets = [] + split_api = mocker.Mock() - split_api.fetch_splits.return_value = {'splits': self.splits, 'since': 123, - 'till': 123} - split_sync = SplitSynchronizer(split_api, split_storage) + split_api.fetch_splits.return_value = {'ff': {'d': splits, 's': 123, + 't': 123}, 'rbs': {'d': [], 's': -1, 't': -1}} + split_sync = SplitSynchronizer(split_api, split_storage, rbs_storage) segment_storage = mocker.Mock(spec=SegmentStorage) segment_storage.get_change_number.return_value = 123 segment_api = mocker.Mock() segment_api.fetch_segment.return_value = {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], 'since': 123, 'till': 123} - segment_sync = SegmentSynchronizer(segment_api, split_storage, segment_storage) + segment_sync = SegmentSynchronizer(segment_api, split_storage, segment_storage, rbs_storage) split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), mocker.Mock(), mocker.Mock()) synchronizer = Synchronizer(split_synchronizers, mocker.Mock(spec=SplitTasks)) + self.clear = False + def clear(): + self.clear = True + split_storage.clear = clear + rbs_storage.clear = clear + synchronizer.sync_all() - inserted_split = split_storage.put.mock_calls[0][1][0] + inserted_split = split_storage.update.mock_calls[0][1][0][0] assert isinstance(inserted_split, Split) assert inserted_split.name == 'some_name' @@ -130,14 +257,18 @@ def test_start_periodic_data_recording(self, mocker): impression_task = mocker.Mock(spec=ImpressionsSyncTask) impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTask) event_task = mocker.Mock(spec=EventsSyncTask) + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) split_tasks = SplitTasks(mocker.Mock(), mocker.Mock(), impression_task, event_task, - impression_count_task) + impression_count_task, unique_keys_task, clear_filter_task) synchronizer = Synchronizer(mocker.Mock(spec=SplitSynchronizers), split_tasks) synchronizer.start_periodic_data_recording() assert len(impression_task.start.mock_calls) == 1 assert len(impression_count_task.start.mock_calls) == 1 assert len(event_task.start.mock_calls) == 1 + assert len(unique_keys_task.start.mock_calls) == 1 + assert len(clear_filter_task.start.mock_calls) == 1 def test_stop_periodic_data_recording(self, mocker): @@ -154,14 +285,20 @@ def stop_mock_2(): impression_count_task.stop.side_effect = stop_mock event_task = mocker.Mock(spec=EventsSyncTask) event_task.stop.side_effect = stop_mock + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + unique_keys_task.stop.side_effect = stop_mock + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) + clear_filter_task.stop.side_effect = stop_mock split_tasks = SplitTasks(mocker.Mock(), mocker.Mock(), impression_task, event_task, - impression_count_task) + impression_count_task, unique_keys_task, clear_filter_task) synchronizer = Synchronizer(mocker.Mock(spec=SplitSynchronizers), split_tasks) synchronizer.stop_periodic_data_recording(True) assert len(impression_task.stop.mock_calls) == 1 assert len(impression_count_task.stop.mock_calls) == 1 assert len(event_task.stop.mock_calls) == 1 + assert len(unique_keys_task.stop.mock_calls) == 1 + assert len(clear_filter_task.stop.mock_calls) == 1 def test_shutdown(self, mocker): @@ -182,13 +319,17 @@ def stop_mock_2(): impression_count_task.stop.side_effect = stop_mock event_task = mocker.Mock(spec=EventsSyncTask) event_task.stop.side_effect = stop_mock + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + unique_keys_task.stop.side_effect = stop_mock + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) + clear_filter_task.stop.side_effect = stop_mock segment_sync = mocker.Mock(spec=SegmentSynchronizer) split_synchronizers = SplitSynchronizers(mocker.Mock(), segment_sync, mocker.Mock(), - mocker.Mock(), mocker.Mock()) + mocker.Mock(), mocker.Mock(), mocker.Mock()) split_tasks = SplitTasks(split_task, segment_task, impression_task, event_task, - impression_count_task) + impression_count_task, unique_keys_task, clear_filter_task) synchronizer = Synchronizer(split_synchronizers, split_tasks) synchronizer.shutdown(True) @@ -198,6 +339,8 @@ def stop_mock_2(): assert len(impression_task.stop.mock_calls) == 1 assert len(impression_count_task.stop.mock_calls) == 1 assert len(event_task.stop.mock_calls) == 1 + assert len(unique_keys_task.stop.mock_calls) == 1 + assert len(clear_filter_task.stop.mock_calls) == 1 def test_sync_all_ok(self, mocker): """Test that 3 attempts are done before failing.""" @@ -207,7 +350,7 @@ def test_sync_all_ok(self, mocker): def sync_splits(*_): """Sync Splits.""" counts['splits'] += 1 - return True + return [] def sync_segments(*_): """Sync Segments.""" @@ -236,7 +379,7 @@ def sync_splits(*_): split_tasks = mocker.Mock(spec=SplitTasks) synchronizer = Synchronizer(split_synchronizers, split_tasks) - synchronizer.sync_all() + synchronizer.sync_all(2) assert counts['splits'] == 3 def test_sync_all_segment_attempts(self, mocker): @@ -254,5 +397,761 @@ def sync_segments(*_): split_tasks = mocker.Mock(spec=SplitTasks) synchronizer = Synchronizer(split_synchronizers, split_tasks) - synchronizer.sync_all() + synchronizer._synchronize_segments() + assert counts['segments'] == 1 + + +class SynchronizerAsyncTests(object): + + @pytest.mark.asyncio + async def test_sync_all_failed_splits(self, mocker): + api = mocker.Mock() + storage = mocker.Mock() + rbs_storage = mocker.Mock() + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + async def run(x, y, c): + raise APIException("something broke") + api.fetch_splits = run + + async def get_change_number(): + return 1234 + storage.get_change_number = get_change_number + rbs_storage.get_change_number = get_change_number + + split_sync = SplitSynchronizerAsync(api, storage, rbs_storage) + split_synchronizers = SplitSynchronizers(split_sync, mocker.Mock(), mocker.Mock(), + mocker.Mock(), mocker.Mock()) + sychronizer = SynchronizerAsync(split_synchronizers, mocker.Mock(spec=SplitTasks)) + + await sychronizer.synchronize_splits(None) # APIExceptions are handled locally and should not be propagated! + + # test forcing to have only one retry attempt and then exit + await sychronizer.sync_all(1) # sync_all should not throw! + + @pytest.mark.asyncio + async def test_sync_all_failed_splits_with_flagsets(self, mocker): + api = mocker.Mock() + storage = mocker.Mock() + rbs_storage = mocker.Mock() + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + async def get_change_number(): + pass + storage.get_change_number = get_change_number + rbs_storage.get_change_number = get_change_number + + async def run(x, y, c): + raise APIException("something broke", 414) + api.fetch_splits = run + + split_sync = SplitSynchronizerAsync(api, storage, rbs_storage) + split_synchronizers = SplitSynchronizers(split_sync, mocker.Mock(), mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = SynchronizerAsync(split_synchronizers, mocker.Mock(spec=SplitTasks)) + + await synchronizer.synchronize_splits(None) # APIExceptions are handled locally and should not be propagated! + + # test forcing to have only one retry attempt and then exit + await synchronizer.sync_all(3) # sync_all should not throw! + assert synchronizer._backoff._attempt == 0 + + @pytest.mark.asyncio + async def test_sync_all_failed_segments(self, mocker): + api = mocker.Mock() + storage = mocker.Mock() + split_storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + split_sync = mocker.Mock(spec=SplitSynchronizer) + split_sync.synchronize_splits.return_value = None + + async def run(x, y, c): + raise APIException("something broke") + api.fetch_segment = run + + async def get_segment_names(): + return ['seg'] + split_storage.get_segment_names = get_segment_names + + async def get_segment_names_rbs(): + return [] + rbs_storage.get_segment_names = get_segment_names_rbs + + segment_sync = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + sychronizer = SynchronizerAsync(split_synchronizers, mocker.Mock(spec=SplitTasks)) + + await sychronizer.sync_all(1) # SyncAll should not throw! + assert not await sychronizer._synchronize_segments() + await segment_sync.shutdown() + + @pytest.mark.asyncio + async def test_synchronize_splits(self, mocker): + internal_events_queue = asyncio.Queue() + split_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + split_api = mocker.Mock() + + async def fetch_splits(change, rb, options): + return {'ff': {'d': splits, 's': 123, + 't': 123}, 'rbs': {'d': [], 's': -1, 't': -1}} + + split_api.fetch_splits = fetch_splits + + split_sync = SplitSynchronizerAsync(split_api, split_storage, rbs_storage) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + segment_api = mocker.Mock() + + async def get_change_number(): + return 123 + split_storage.get_change_number = get_change_number + + async def fetch_segment(segment_name, change, options): + return {'name': 'segmentA', 'added': ['key1', 'key2', + 'key3'], 'removed': [], 'since': 123, 'till': 123} + segment_api.fetch_segment = fetch_segment + + segment_sync = SegmentSynchronizerAsync(segment_api, split_storage, segment_storage, rbs_storage) + split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = SynchronizerAsync(split_synchronizers, mocker.Mock(spec=SplitTasks)) + + await synchronizer.synchronize_splits(123) + + inserted_split = await split_storage.get('some_name') + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + + await segment_sync._jobs.await_completion() + inserted_segment = await segment_storage.get('segmentA') + assert inserted_segment.name == 'segmentA' + assert inserted_segment.keys == {'key1', 'key2', 'key3'} + + await segment_sync.shutdown() + + @pytest.mark.asyncio + async def test_synchronize_splits_calling_segment_sync_once(self, mocker): + internal_events_queue = asyncio.Queue() + split_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + async def get_change_number(): + return 123 + split_storage.get_change_number = get_change_number + + split_api = mocker.Mock() + async def fetch_splits(change, rb, options): + return {'ff': {'d': splits, 's': 123, + 't': 123}, 'rbs': {'d': [], 's': -1, 't': -1}} + split_api.fetch_splits = fetch_splits + + split_sync = SplitSynchronizerAsync(split_api, split_storage, rbs_storage) + counts = {'segments': 0} + + segment_sync = mocker.Mock() + async def sync_segments(*_): + """Sync Segments.""" + counts['segments'] += 1 + return True + segment_sync.synchronize_segments = sync_segments + + async def segment_exist_in_storage(segment): + return False + segment_sync.segment_exist_in_storage = segment_exist_in_storage + + split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = SynchronizerAsync(split_synchronizers, mocker.Mock(spec=SplitTasks)) + await synchronizer.synchronize_splits(123, True) + + assert counts['segments'] == 1 + + @pytest.mark.asyncio + async def test_sync_all(self, mocker): + internal_events_queue = asyncio.Queue() + split_storage = InMemorySplitStorageAsync(internal_events_queue) + rbs_storage = InMemoryRuleBasedSegmentStorageAsync(internal_events_queue) + async def get_change_number(): + return 123 + split_storage.get_change_number = get_change_number + + self.added_split = None + async def update(split, deleted, change_number): + if len(split) > 0: + self.added_split = split + split_storage.update = update + + async def get_segment_names(): + return ['segmentA'] + split_storage.get_segment_names = get_segment_names + + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + split_storage.flag_set_filter = flag_set_filter + split_storage.flag_set_filter.flag_sets = {} + split_storage.flag_set_filter.sorted_flag_sets = [] + + split_api = mocker.Mock() + async def fetch_splits(change, rb, options): + return {'ff': {'d': splits, 's': 123, + 't': 123}, 'rbs': {'d': [], 's': -1, 't': -1}} + split_api.fetch_splits = fetch_splits + + split_sync = SplitSynchronizerAsync(split_api, split_storage, rbs_storage) + segment_storage = InMemorySegmentStorageAsync(internal_events_queue) + async def get_change_number(segment): + return 123 + segment_storage.get_change_number = get_change_number + + self.inserted_segment = [] + async def update(segment, added, removed, till): + self.inserted_segment.append(segment) + self.inserted_segment.append(added) + self.inserted_segment.append(removed) + segment_storage.update = update + + segment_api = mocker.Mock() + async def fetch_segment(segment_name, change, options): + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], + 'removed': [], 'since': 123, 'till': 123} + segment_api.fetch_segment = fetch_segment + + segment_sync = SegmentSynchronizerAsync(segment_api, split_storage, segment_storage, rbs_storage) + split_synchronizers = SplitSynchronizers(split_sync, segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + synchronizer = SynchronizerAsync(split_synchronizers, mocker.Mock(spec=SplitTasks)) + await synchronizer.sync_all() + await segment_sync._jobs.await_completion() + + assert isinstance(self.added_split[0], Split) + assert self.added_split[0].name == 'some_name' + + assert self.inserted_segment[0] == 'segmentA' + assert self.inserted_segment[1] == ['key1', 'key2', 'key3'] + assert self.inserted_segment[2] == [] + + @pytest.mark.asyncio + async def test_start_periodic_fetching(self, mocker): + split_task = mocker.Mock(spec=SplitSynchronizationTask) + segment_task = mocker.Mock(spec=SegmentSynchronizationTask) + split_tasks = SplitTasks(split_task, segment_task, mocker.Mock(), mocker.Mock(), + mocker.Mock()) + synchronizer = SynchronizerAsync(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.start_periodic_fetching() + + assert len(split_task.start.mock_calls) == 1 + assert len(segment_task.start.mock_calls) == 1 + + @pytest.mark.asyncio + async def test_stop_periodic_fetching(self, mocker): + split_task = mocker.Mock(spec=SplitSynchronizationTaskAsync) + segment_task = mocker.Mock(spec=SegmentSynchronizationTaskAsync) + segment_sync = mocker.Mock(spec=SegmentSynchronizerAsync) + split_synchronizers = SplitSynchronizers(mocker.Mock(), segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock()) + split_tasks = SplitTasks(split_task, segment_task, mocker.Mock(), mocker.Mock(), + mocker.Mock()) + synchronizer = SynchronizerAsync(split_synchronizers, split_tasks) + self.split_task_stopped = 0 + async def stop_split(): + self.split_task_stopped += 1 + split_task.stop = stop_split + + self.segment_task_stopped = 0 + async def stop_segment(): + self.segment_task_stopped += 1 + segment_task.stop = stop_segment + + self.segment_sync_stopped = 0 + async def shutdown(): + self.segment_sync_stopped += 1 + segment_sync.shutdown = shutdown + + await synchronizer.stop_periodic_fetching() + + assert self.split_task_stopped == 1 + assert self.segment_task_stopped == 1 + assert self.segment_sync_stopped == 0 + + def test_start_periodic_data_recording(self, mocker): + impression_task = mocker.Mock(spec=ImpressionsSyncTaskAsync) + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTaskAsync) + event_task = mocker.Mock(spec=EventsSyncTaskAsync) + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTaskAsync) + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTaskAsync) + split_tasks = SplitTasks(None, None, impression_task, event_task, + impression_count_task, unique_keys_task, clear_filter_task) + synchronizer = SynchronizerAsync(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.start_periodic_data_recording() + + assert len(impression_task.start.mock_calls) == 1 + assert len(impression_count_task.start.mock_calls) == 1 + assert len(event_task.start.mock_calls) == 1 + +class RedisSynchronizerTests(object): + def test_start_periodic_data_recording(self, mocker): + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTask) + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) + split_tasks = SplitTasks(None, None, None, None, + impression_count_task, + None, + unique_keys_task, + clear_filter_task + ) + synchronizer = RedisSynchronizer(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.start_periodic_data_recording() + + assert len(impression_count_task.start.mock_calls) == 1 + assert len(unique_keys_task.start.mock_calls) == 1 + assert len(clear_filter_task.start.mock_calls) == 1 + + def test_stop_periodic_data_recording(self, mocker): + + def stop_mock(event): + event.set() + return + + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTask) + impression_count_task.stop.side_effect = stop_mock + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + unique_keys_task.stop.side_effect = stop_mock + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) + clear_filter_task.stop.side_effect = stop_mock + + split_tasks = SplitTasks(None, None, None, None, + impression_count_task, + None, + unique_keys_task, + clear_filter_task + ) + synchronizer = RedisSynchronizer(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.stop_periodic_data_recording(True) + + assert len(impression_count_task.stop.mock_calls) == 1 + assert len(unique_keys_task.stop.mock_calls) == 1 + assert len(clear_filter_task.stop.mock_calls) == 1 + + def test_shutdown(self, mocker): + + def stop_mock(event): + event.set() + return + + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTask) + impression_count_task.stop.side_effect = stop_mock + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + unique_keys_task.stop.side_effect = stop_mock + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) + clear_filter_task.stop.side_effect = stop_mock + + segment_sync = mocker.Mock(spec=SegmentSynchronizer) + + split_tasks = SplitTasks(None, None, None, None, + impression_count_task, + None, + unique_keys_task, + clear_filter_task + ) + synchronizer = RedisSynchronizer(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.shutdown(True) + + assert len(impression_count_task.stop.mock_calls) == 1 + assert len(unique_keys_task.stop.mock_calls) == 1 + assert len(clear_filter_task.stop.mock_calls) == 1 + +class RedisSynchronizerAsyncTests(object): + @pytest.mark.asyncio + async def test_start_periodic_data_recording(self, mocker): + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTaskAsync) + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTaskAsync) + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTaskAsync) + split_tasks = SplitTasks(None, None, None, None, + impression_count_task, + None, + unique_keys_task, + clear_filter_task + ) + synchronizer = RedisSynchronizerAsync(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.start_periodic_data_recording() + + assert len(impression_count_task.start.mock_calls) == 1 + assert len(unique_keys_task.start.mock_calls) == 1 + assert len(clear_filter_task.start.mock_calls) == 1 + + @pytest.mark.asyncio + async def test_stop_periodic_data_recording(self, mocker): + impression_task = mocker.Mock(spec=ImpressionsSyncTaskAsync) + self.stop_imp_calls = 0 + async def stop_imp(arg=None): + self.stop_imp_calls += 1 + return + impression_task.stop = stop_imp + + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTaskAsync) + self.stop_imp_count_calls = 0 + async def stop_imp_count(arg=None): + self.stop_imp_count_calls += 1 + return + impression_count_task.stop = stop_imp_count + + event_task = mocker.Mock(spec=EventsSyncTaskAsync) + self.stop_event_calls = 0 + async def stop_event(arg=None): + self.stop_event_calls += 1 + return + event_task.stop = stop_event + + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTaskAsync) + self.stop_unique_keys_calls = 0 + async def stop_unique_keys(arg=None): + self.stop_unique_keys_calls += 1 + return + unique_keys_task.stop = stop_unique_keys + + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTaskAsync) + self.stop_clear_filter_calls = 0 + async def stop_clear_filter(arg=None): + self.stop_clear_filter_calls += 1 + return + clear_filter_task.stop = stop_clear_filter + + split_tasks = SplitTasks(mocker.Mock(), mocker.Mock(), impression_task, event_task, + impression_count_task, unique_keys_task, clear_filter_task) + synchronizer = SynchronizerAsync(mocker.Mock(spec=SplitSynchronizers), split_tasks) + await synchronizer.stop_periodic_data_recording(True) + + assert self.stop_imp_count_calls == 1 + assert self.stop_imp_calls == 1 + assert self.stop_event_calls == 1 + assert self.stop_unique_keys_calls == 1 + assert self.stop_clear_filter_calls == 1 + + @pytest.mark.asyncio + async def test_shutdown(self, mocker): + split_task = mocker.Mock(spec=SplitSynchronizationTask) + self.split_task_stopped = 0 + async def stop_split(): + self.split_task_stopped += 1 + split_task.stop = stop_split + + segment_task = mocker.Mock(spec=SegmentSynchronizationTask) + self.segment_task_stopped = 0 + async def stop_segment(): + self.segment_task_stopped += 1 + segment_task.stop = stop_segment + + impression_task = mocker.Mock(spec=ImpressionsSyncTaskAsync) + self.stop_imp_calls = 0 + async def stop_imp(arg=None): + self.stop_imp_calls += 1 + return + impression_task.stop = stop_imp + + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTaskAsync) + self.stop_imp_count_calls = 0 + async def stop_imp_count(arg=None): + self.stop_imp_count_calls += 1 + return + impression_count_task.stop = stop_imp_count + + event_task = mocker.Mock(spec=EventsSyncTaskAsync) + self.stop_event_calls = 0 + async def stop_event(arg=None): + self.stop_event_calls += 1 + return + event_task.stop = stop_event + + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTaskAsync) + self.stop_unique_keys_calls = 0 + async def stop_unique_keys(arg=None): + self.stop_unique_keys_calls += 1 + return + unique_keys_task.stop = stop_unique_keys + + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTaskAsync) + self.stop_clear_filter_calls = 0 + async def stop_clear_filter(arg=None): + self.stop_clear_filter_calls += 1 + return + clear_filter_task.stop = stop_clear_filter + + segment_sync = mocker.Mock(spec=SegmentSynchronizerAsync) + self.segment_sync_stopped = 0 + async def shutdown(): + self.segment_sync_stopped += 1 + segment_sync.shutdown = shutdown + + split_synchronizers = SplitSynchronizers(mocker.Mock(), segment_sync, mocker.Mock(), + mocker.Mock(), mocker.Mock(), mocker.Mock()) + split_tasks = SplitTasks(split_task, segment_task, impression_task, event_task, + impression_count_task, unique_keys_task, clear_filter_task) + synchronizer = SynchronizerAsync(split_synchronizers, split_tasks) + await synchronizer.shutdown(True) + + assert self.split_task_stopped == 1 + assert self.segment_task_stopped == 1 + assert self.segment_sync_stopped == 1 + assert self.stop_imp_count_calls == 1 + assert self.stop_imp_calls == 1 + assert self.stop_event_calls == 1 + assert self.stop_unique_keys_calls == 1 + assert self.stop_clear_filter_calls == 1 + + @pytest.mark.asyncio + async def test_sync_all_ok(self, mocker): + """Test that 3 attempts are done before failing.""" + split_synchronizers = mocker.Mock(spec=SplitSynchronizers) + counts = {'splits': 0, 'segments': 0} + + async def sync_splits(*_): + """Sync Splits.""" + counts['splits'] += 1 + return [] + + async def sync_segments(*_): + """Sync Segments.""" + counts['segments'] += 1 + return True + + split_synchronizers.split_sync.synchronize_splits = sync_splits + split_synchronizers.segment_sync.synchronize_segments = sync_segments + split_tasks = mocker.Mock(spec=SplitTasks) + synchronizer = SynchronizerAsync(split_synchronizers, split_tasks) + + await synchronizer.sync_all() + assert counts['splits'] == 1 assert counts['segments'] == 1 + + @pytest.mark.asyncio + async def test_sync_all_split_attempts(self, mocker): + """Test that 3 attempts are done before failing.""" + split_synchronizers = mocker.Mock(spec=SplitSynchronizers) + counts = {'splits': 0, 'segments': 0} + async def sync_splits(*_): + """Sync Splits.""" + counts['splits'] += 1 + raise Exception('sarasa') + + split_synchronizers.split_sync.synchronize_splits = sync_splits + split_tasks = mocker.Mock(spec=SplitTasks) + synchronizer = SynchronizerAsync(split_synchronizers, split_tasks) + + await synchronizer.sync_all(2) + assert counts['splits'] == 3 + + @pytest.mark.asyncio + async def test_sync_all_segment_attempts(self, mocker): + """Test that segments don't trigger retries.""" + split_synchronizers = mocker.Mock(spec=SplitSynchronizers) + counts = {'splits': 0, 'segments': 0} + + async def sync_segments(*_): + """Sync Segments.""" + counts['segments'] += 1 + return False + split_synchronizers.segment_sync.synchronize_segments = sync_segments + + split_tasks = mocker.Mock(spec=SplitTasks) + synchronizer = SynchronizerAsync(split_synchronizers, split_tasks) + + await synchronizer._synchronize_segments() + assert counts['segments'] == 1 + + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTaskAsync) + self.imp_count_calls = 0 + async def imp_count_stop_mock(): + self.imp_count_calls += 1 + impression_count_task.stop = imp_count_stop_mock + + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTaskAsync) + self.unique_keys_calls = 0 + async def unique_keys_stop_mock(): + self.unique_keys_calls += 1 + unique_keys_task.stop = unique_keys_stop_mock + + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTaskAsync) + self.clear_filter_calls = 0 + async def clear_filter_stop_mock(): + self.clear_filter_calls += 1 + clear_filter_task.stop = clear_filter_stop_mock + + split_tasks = SplitTasks(None, None, None, None, + impression_count_task, + None, + unique_keys_task, + clear_filter_task + ) + synchronizer = RedisSynchronizerAsync(mocker.Mock(spec=SplitSynchronizers), split_tasks) + await synchronizer.stop_periodic_data_recording(True) + + assert self.imp_count_calls == 1 + assert self.unique_keys_calls == 1 + assert self.clear_filter_calls == 1 + + def test_shutdown(self, mocker): + + def stop_mock(event): + event.set() + return + + impression_count_task = mocker.Mock(spec=ImpressionsCountSyncTask) + impression_count_task.stop.side_effect = stop_mock + unique_keys_task = mocker.Mock(spec=UniqueKeysSyncTask) + unique_keys_task.stop.side_effect = stop_mock + clear_filter_task = mocker.Mock(spec=ClearFilterSyncTask) + clear_filter_task.stop.side_effect = stop_mock + + segment_sync = mocker.Mock(spec=SegmentSynchronizer) + + split_tasks = SplitTasks(None, None, None, None, + impression_count_task, + None, + unique_keys_task, + clear_filter_task + ) + synchronizer = RedisSynchronizer(mocker.Mock(spec=SplitSynchronizers), split_tasks) + synchronizer.shutdown(True) + + assert len(impression_count_task.stop.mock_calls) == 1 + assert len(unique_keys_task.stop.mock_calls) == 1 + assert len(clear_filter_task.stop.mock_calls) == 1 + + +class LocalhostSynchronizerTests(object): + + @mock.patch('splitio.sync.segment.LocalSegmentSynchronizer.synchronize_segments') + def test_synchronize_splits(self, mocker): + split_sync = LocalSplitSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()) + segment_sync = LocalSegmentSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()) + synchronizers = SplitSynchronizers(split_sync, segment_sync, None, None, None) + local_synchronizer = LocalhostSynchronizer(synchronizers, mocker.Mock(), mocker.Mock()) + + def synchronize_splits(*args, **kwargs): + return ["segmentA", "segmentB"] + split_sync.synchronize_splits = synchronize_splits + + def segment_exist_in_storage(*args, **kwargs): + return False + segment_sync.segment_exist_in_storage = segment_exist_in_storage + + assert(local_synchronizer.synchronize_splits()) + assert(mocker.called) + + def test_start_and_stop_tasks(self, mocker): + synchronizers = SplitSynchronizers( + LocalSplitSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()), + LocalSegmentSynchronizer(mocker.Mock(), mocker.Mock(), mocker.Mock()), None, None, None) + split_task = SplitSynchronizationTask(synchronizers.split_sync.synchronize_splits, 30) + segment_task = SegmentSynchronizationTask(synchronizers.segment_sync.synchronize_segments, 30) + tasks = SplitTasks(split_task, segment_task, None, None, None,) + + self.split_task_start_called = False + def split_task_start(*args, **kwargs): + self.split_task_start_called = True + split_task.start = split_task_start + + self.segment_task_start_called = False + def segment_task_start(*args, **kwargs): + self.segment_task_start_called = True + segment_task.start = segment_task_start + + self.split_task_stop_called = False + def split_task_stop(*args, **kwargs): + self.split_task_stop_called = True + split_task.stop = split_task_stop + + self.segment_task_stop_called = False + def segment_task_stop(*args, **kwargs): + self.segment_task_stop_called = True + segment_task.stop = segment_task_stop + + local_synchronizer = LocalhostSynchronizer(synchronizers, tasks, LocalhostMode.JSON) + local_synchronizer.start_periodic_fetching() + assert(self.split_task_start_called) + assert(self.segment_task_start_called) + + local_synchronizer.stop_periodic_fetching() + assert(self.split_task_stop_called) + assert(self.segment_task_stop_called) + + +class LocalhostSynchronizerAsyncTests(object): + + @pytest.mark.asyncio + async def test_synchronize_splits(self, mocker): + split_sync = LocalSplitSynchronizerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock()) + segment_sync = LocalSegmentSynchronizerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock()) + synchronizers = SplitSynchronizers(split_sync, segment_sync, None, None, None) + local_synchronizer = LocalhostSynchronizerAsync(synchronizers, mocker.Mock(), mocker.Mock()) + + self.called = False + async def synchronize_segments(*args): + self.called = True + segment_sync.synchronize_segments = synchronize_segments + + async def synchronize_splits(*args, **kwargs): + return ["segmentA", "segmentB"] + split_sync.synchronize_splits = synchronize_splits + + async def segment_exist_in_storage(*args, **kwargs): + return False + segment_sync.segment_exist_in_storage = segment_exist_in_storage + + assert(await local_synchronizer.synchronize_splits()) + assert(self.called) + + @pytest.mark.asyncio + async def test_start_and_stop_tasks(self, mocker): + synchronizers = SplitSynchronizers( + LocalSplitSynchronizerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock()), + LocalSegmentSynchronizerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock()), None, None, None) + split_task = SplitSynchronizationTaskAsync(synchronizers.split_sync.synchronize_splits, 30) + segment_task = SegmentSynchronizationTaskAsync(synchronizers.segment_sync.synchronize_segments, 30) + tasks = SplitTasks(split_task, segment_task, None, None, None,) + + self.split_task_start_called = False + def split_task_start(*args, **kwargs): + self.split_task_start_called = True + split_task.start = split_task_start + + self.segment_task_start_called = False + def segment_task_start(*args, **kwargs): + self.segment_task_start_called = True + segment_task.start = segment_task_start + + self.split_task_stop_called = False + async def split_task_stop(*args, **kwargs): + self.split_task_stop_called = True + split_task.stop = split_task_stop + + self.segment_task_stop_called = False + async def segment_task_stop(*args, **kwargs): + self.segment_task_stop_called = True + segment_task.stop = segment_task_stop + + local_synchronizer = LocalhostSynchronizerAsync(synchronizers, tasks, LocalhostMode.JSON) + local_synchronizer.start_periodic_fetching() + assert(self.split_task_start_called) + assert(self.segment_task_start_called) + + await local_synchronizer.stop_periodic_fetching() + assert(self.split_task_stop_called) + assert(self.segment_task_stop_called) diff --git a/tests/sync/test_telemetry.py b/tests/sync/test_telemetry.py new file mode 100644 index 00000000..dd8119e2 --- /dev/null +++ b/tests/sync/test_telemetry.py @@ -0,0 +1,302 @@ +"""Telemetry Worker tests.""" +import unittest.mock as mock +import pytest +import queue +import asyncio + +from splitio.sync.telemetry import TelemetrySynchronizer, TelemetrySynchronizerAsync, InMemoryTelemetrySubmitter, InMemoryTelemetrySubmitterAsync +from splitio.engine.telemetry import TelemetryStorageConsumer, TelemetryStorageConsumerAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync, InMemorySegmentStorage, InMemorySegmentStorageAsync, InMemorySplitStorage, InMemorySplitStorageAsync +from splitio.models.splits import Split, Status +from splitio.models.segments import Segment +from splitio.models.telemetry import StreamingEvents, StreamingEventsAsync +from splitio.api.telemetry import TelemetryAPI + +class TelemetrySynchronizerTests(object): + """Telemetry synchronizer test cases.""" + + @mock.patch('splitio.sync.telemetry.InMemoryTelemetrySubmitter.synchronize_config') + def test_synchronize_config(self, mocker): + telemetry_synchronizer = TelemetrySynchronizer(InMemoryTelemetrySubmitter(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())) + telemetry_synchronizer.synchronize_config() + assert(mocker.called) + + @mock.patch('splitio.sync.telemetry.InMemoryTelemetrySubmitter.synchronize_stats') + def test_synchronize_stats(self, mocker): + telemetry_synchronizer = TelemetrySynchronizer(InMemoryTelemetrySubmitter(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())) + telemetry_synchronizer.synchronize_stats() + assert(mocker.called) + + +class TelemetrySynchronizerAsyncTests(object): + """Telemetry synchronizer async test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_config(self, mocker): + telemetry_synchronizer = TelemetrySynchronizerAsync(InMemoryTelemetrySubmitterAsync(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())) + self.called = False + async def synchronize_config(): + self.called = True + telemetry_synchronizer.synchronize_config = synchronize_config + await telemetry_synchronizer.synchronize_config() + assert(self.called) + + @pytest.mark.asyncio + async def test_synchronize_stats(self, mocker): + telemetry_synchronizer = TelemetrySynchronizer(InMemoryTelemetrySubmitter(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())) + self.called = False + async def synchronize_stats(): + self.called = True + telemetry_synchronizer.synchronize_stats = synchronize_stats + await telemetry_synchronizer.synchronize_stats() + assert(self.called) + + +class TelemetrySubmitterTests(object): + """Telemetry submitter test cases.""" + + def test_synchronize_telemetry(self, mocker): + api = mocker.Mock(spec=TelemetryAPI) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_consumer = TelemetryStorageConsumer(telemetry_storage) + events_queue = queue.Queue() + split_storage = InMemorySplitStorage(events_queue) + split_storage.update([Split('split1', 1234, 1, False, 'user', Status.ACTIVE, 123)], [], -1) + segment_storage = InMemorySegmentStorage(events_queue) + segment_storage.put(Segment('segment1', [], 123)) + telemetry_submitter = InMemoryTelemetrySubmitter(telemetry_consumer, split_storage, segment_storage, api) + + telemetry_storage._counters._impressions_queued = 100 + telemetry_storage._counters._impressions_deduped = 30 + telemetry_storage._counters._impressions_dropped = 0 + telemetry_storage._counters._events_queued = 20 + telemetry_storage._counters._events_dropped = 10 + telemetry_storage._counters._auth_rejections = 1 + telemetry_storage._counters._token_refreshes = 3 + telemetry_storage._counters._session_length = 3 + telemetry_storage._counters._update_from_sse['sp'] = 3 + + telemetry_storage._method_exceptions._treatment = 10 + telemetry_storage._method_exceptions._treatments = 1 + telemetry_storage._method_exceptions._treatment_with_config = 5 + telemetry_storage._method_exceptions._treatments_with_config = 1 + telemetry_storage._method_exceptions._treatments_by_flag_set = 2 + telemetry_storage._method_exceptions._treatments_by_flag_sets = 3 + telemetry_storage._method_exceptions._treatments_with_config_by_flag_set = 4 + telemetry_storage._method_exceptions._treatments_with_config_by_flag_sets = 6 + telemetry_storage._method_exceptions._track = 3 + + telemetry_storage._last_synchronization._split = 5 + telemetry_storage._last_synchronization._segment = 3 + telemetry_storage._last_synchronization._impression = 10 + telemetry_storage._last_synchronization._impression_count = 0 + telemetry_storage._last_synchronization._event = 4 + telemetry_storage._last_synchronization._telemetry = 0 + telemetry_storage._last_synchronization._token = 3 + + telemetry_storage._http_sync_errors._split = {'500': 3, '501': 2} + telemetry_storage._http_sync_errors._segment = {'401': 1} + telemetry_storage._http_sync_errors._impression = {'500': 1} + telemetry_storage._http_sync_errors._impression_count = {'401': 5} + telemetry_storage._http_sync_errors._event = {'404': 10} + telemetry_storage._http_sync_errors._telemetry = {'501': 3} + telemetry_storage._http_sync_errors._token = {'505': 11} + + telemetry_storage._streaming_events = StreamingEvents() + telemetry_storage._tags = ['tag1'] + + telemetry_storage._method_latencies._treatment = [1] + [0] * 22 + telemetry_storage._method_latencies._treatments = [0] * 23 + telemetry_storage._method_latencies._treatment_with_config = [0] * 23 + telemetry_storage._method_latencies._treatments_with_config = [0] * 23 + telemetry_storage._method_latencies._treatments_by_flag_set = [1] + [0] * 22 + telemetry_storage._method_latencies._treatments_by_flag_sets = [0] * 23 + telemetry_storage._method_latencies._treatments_with_config_by_flag_set = [1] + [0] * 22 + telemetry_storage._method_latencies._treatments_with_config_by_flag_sets = [0] * 23 + telemetry_storage._method_latencies._track = [0] * 23 + + telemetry_storage._http_latencies._split = [1] + [0] * 22 + telemetry_storage._http_latencies._segment = [0] * 23 + telemetry_storage._http_latencies._impression = [0] * 23 + telemetry_storage._http_latencies._impression_count = [0] * 23 + telemetry_storage._http_latencies._event = [0] * 23 + telemetry_storage._http_latencies._telemetry = [0] * 23 + telemetry_storage._http_latencies._token = [0] * 23 + + telemetry_storage.record_config({'operationMode': 'inmemory', + 'storageType': None, + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG', + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'activeFactoryCount': 1, + 'notReady': 0, + 'timeUntilReady': 1 + }, {}, 0, 0 + ) + self.formatted_config = "" + def record_init(*args, **kwargs): + self.formatted_config = args[0] + + api.record_init.side_effect = record_init + telemetry_submitter.synchronize_config() + assert(self.formatted_config == telemetry_submitter._telemetry_init_consumer.get_config_stats()) + + def record_stats(*args, **kwargs): + self.formatted_stats = args[0] + + api.record_stats.side_effect = record_stats + telemetry_submitter.synchronize_stats() + assert(self.formatted_stats == { + "iQ": 100, + "iDe": 30, + "iDr": 0, + "eQ": 20, + "eD": 10, + "lS": {"sp": 5, "se": 3, "im": 10, "ic": 0, "ev": 4, "te": 0, "to": 3}, + "t": ["tag1"], + "hE": {"sp": {"500": 3, "501": 2}, "se": {"401": 1}, "im": {"500": 1}, "ic": {"401": 5}, "ev": {"404": 10}, "te": {"501": 3}, "to": {"505": 11}}, + "hL": {"sp": [1] + [0] * 22, "se": [0] * 23, "im": [0] * 23, "ic": [0] * 23, "ev": [0] * 23, "te": [0] * 23, "to": [0] * 23}, + "aR": 1, + "tR": 3, + "sE": [], + "sL": 3, + "mE": {"t": 10, "ts": 1, "tc": 5, "tcs": 1, "tf": 2, "tfs": 3, "tcf": 4, "tcfs": 6, "tr": 3}, + "mL": {"t": [1] + [0] * 22, "ts": [0] * 23, "tc": [0] * 23, "tcs": [0] * 23, "tf": [1] + [0] * 22, "tfs": [0] * 23, "tcf": [1] + [0] * 22, "tcfs": [0] * 23, "tr": [0] * 23}, + "spC": 1, + "seC": 1, + "skC": 0, + "ufs": {"rbs": 0, "sp": 3}, + "t": ['tag1'] + }) + + +class TelemetrySubmitterAsyncTests(object): + """Telemetry submitter async test cases.""" + + @pytest.mark.asyncio + async def test_synchronize_telemetry(self, mocker): + api = mocker.Mock(spec=TelemetryAPI) + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_consumer = TelemetryStorageConsumerAsync(telemetry_storage) + split_storage = InMemorySplitStorageAsync(asyncio.Queue()) + await split_storage.update([Split('split1', 1234, 1, False, 'user', Status.ACTIVE, 123)], [], -1) + segment_storage = InMemorySegmentStorageAsync(asyncio.Queue()) + await segment_storage.put(Segment('segment1', [], 123)) + telemetry_submitter = InMemoryTelemetrySubmitterAsync(telemetry_consumer, split_storage, segment_storage, api) + + telemetry_storage._counters._impressions_queued = 100 + telemetry_storage._counters._impressions_deduped = 30 + telemetry_storage._counters._impressions_dropped = 0 + telemetry_storage._counters._events_queued = 20 + telemetry_storage._counters._events_dropped = 10 + telemetry_storage._counters._auth_rejections = 1 + telemetry_storage._counters._token_refreshes = 3 + telemetry_storage._counters._session_length = 3 + telemetry_storage._counters._update_from_sse['sp'] = 3 + + telemetry_storage._method_exceptions._treatment = 10 + telemetry_storage._method_exceptions._treatments = 1 + telemetry_storage._method_exceptions._treatment_with_config = 5 + telemetry_storage._method_exceptions._treatments_with_config = 1 + telemetry_storage._method_exceptions._treatments_by_flag_set = 2 + telemetry_storage._method_exceptions._treatments_by_flag_sets = 3 + telemetry_storage._method_exceptions._treatments_with_config_by_flag_set = 4 + telemetry_storage._method_exceptions._treatments_with_config_by_flag_sets = 6 + telemetry_storage._method_exceptions._track = 3 + + telemetry_storage._last_synchronization._split = 5 + telemetry_storage._last_synchronization._segment = 3 + telemetry_storage._last_synchronization._impression = 10 + telemetry_storage._last_synchronization._impression_count = 0 + telemetry_storage._last_synchronization._event = 4 + telemetry_storage._last_synchronization._telemetry = 0 + telemetry_storage._last_synchronization._token = 3 + + telemetry_storage._http_sync_errors._split = {'500': 3, '501': 2} + telemetry_storage._http_sync_errors._segment = {'401': 1} + telemetry_storage._http_sync_errors._impression = {'500': 1} + telemetry_storage._http_sync_errors._impression_count = {'401': 5} + telemetry_storage._http_sync_errors._event = {'404': 10} + telemetry_storage._http_sync_errors._telemetry = {'501': 3} + telemetry_storage._http_sync_errors._token = {'505': 11} + + telemetry_storage._streaming_events = await StreamingEventsAsync.create() + telemetry_storage._tags = ['tag1'] + + telemetry_storage._method_latencies._treatment = [1] + [0] * 22 + telemetry_storage._method_latencies._treatments = [0] * 23 + telemetry_storage._method_latencies._treatment_with_config = [0] * 23 + telemetry_storage._method_latencies._treatments_with_config = [0] * 23 + telemetry_storage._method_latencies._treatments_by_flag_set = [1] + [0] * 22 + telemetry_storage._method_latencies._treatments_by_flag_sets = [0] * 23 + telemetry_storage._method_latencies._treatments_with_config_by_flag_set = [1] + [0] * 22 + telemetry_storage._method_latencies._treatments_with_config_by_flag_sets = [0] * 23 + telemetry_storage._method_latencies._track = [0] * 23 + + telemetry_storage._http_latencies._split = [1] + [0] * 22 + telemetry_storage._http_latencies._segment = [0] * 23 + telemetry_storage._http_latencies._impression = [0] * 23 + telemetry_storage._http_latencies._impression_count = [0] * 23 + telemetry_storage._http_latencies._event = [0] * 23 + telemetry_storage._http_latencies._telemetry = [0] * 23 + telemetry_storage._http_latencies._token = [0] * 23 + + await telemetry_storage.record_config({'operationMode': 'inmemory', + 'storageType': None, + 'streamingEnabled': True, + 'impressionsQueueSize': 100, + 'eventsQueueSize': 200, + 'impressionsMode': 'DEBUG', + 'impressionListener': None, + 'featuresRefreshRate': 30, + 'segmentsRefreshRate': 30, + 'impressionsRefreshRate': 60, + 'eventsPushRate': 60, + 'metricsRefreshRate': 10, + 'activeFactoryCount': 1, + 'notReady': 0, + 'timeUntilReady': 1 + }, {}, 0, 0 + ) + self.formatted_config = "" + async def record_init(*args, **kwargs): + self.formatted_config = args[0] + api.record_init = record_init + + await telemetry_submitter.synchronize_config() + assert(self.formatted_config == await telemetry_submitter._telemetry_init_consumer.get_config_stats()) + + async def record_stats(*args, **kwargs): + self.formatted_stats = args[0] + api.record_stats = record_stats + + await telemetry_submitter.synchronize_stats() + assert(self.formatted_stats == { + "iQ": 100, + "iDe": 30, + "iDr": 0, + "eQ": 20, + "eD": 10, + "lS": {"sp": 5, "se": 3, "im": 10, "ic": 0, "ev": 4, "te": 0, "to": 3}, + "t": ["tag1"], + "hE": {"sp": {"500": 3, "501": 2}, "se": {"401": 1}, "im": {"500": 1}, "ic": {"401": 5}, "ev": {"404": 10}, "te": {"501": 3}, "to": {"505": 11}}, + "hL": {"sp": [1] + [0] * 22, "se": [0] * 23, "im": [0] * 23, "ic": [0] * 23, "ev": [0] * 23, "te": [0] * 23, "to": [0] * 23}, + "aR": 1, + "tR": 3, + "sE": [], + "sL": 3, + "mE": {"t": 10, "ts": 1, "tc": 5, "tcs": 1, "tf": 2, "tfs": 3, "tcf": 4, "tcfs": 6, "tr": 3}, + "mL": {"t": [1] + [0] * 22, "ts": [0] * 23, "tc": [0] * 23, "tcs": [0] * 23, "tf": [1] + [0] * 22, "tfs": [0] * 23, "tcf": [1] + [0] * 22, "tcfs": [0] * 23, "tr": [0] * 23}, + "spC": 1, + "seC": 1, + "skC": 0, + "ufs": {"rbs": 0, "sp": 3}, + "t": ['tag1'] + }) diff --git a/tests/sync/test_unique_keys_sync.py b/tests/sync/test_unique_keys_sync.py new file mode 100644 index 00000000..47cedaab --- /dev/null +++ b/tests/sync/test_unique_keys_sync.py @@ -0,0 +1,106 @@ +"""Split Worker tests.""" +import unittest.mock as mock +import pytest + +from splitio.engine.impressions.adapters import InMemorySenderAdapter, InMemorySenderAdapterAsync +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync +from splitio.sync.unique_keys import UniqueKeysSynchronizer, ClearFilterSynchronizer, UniqueKeysSynchronizerAsync, ClearFilterSynchronizerAsync + +class UniqueKeysSynchronizerTests(object): + """Unique keys synchronizer test cases.""" + + def test_sync_unique_keys_chunks(self, mocker): + total_mtks = 5010 # Use number higher than 5000, which is the default max_bulk_size + unique_keys_tracker = UniqueKeysTracker() + for i in range(0 , total_mtks): + unique_keys_tracker.track('key'+str(i)+'', 'feature1') + sender_adapter = InMemorySenderAdapter(mocker.Mock()) + unique_keys_synchronizer = UniqueKeysSynchronizer(sender_adapter, unique_keys_tracker) + cache, cache_size = unique_keys_synchronizer._uniqe_keys_tracker.get_cache_info_and_pop_all() + assert(cache_size > unique_keys_synchronizer._max_bulk_size) + + bulks = unique_keys_synchronizer._split_cache_to_bulks(cache) + assert(len(bulks) == int(total_mtks / unique_keys_synchronizer._max_bulk_size) + 1) + for i in range(0 , int(total_mtks / unique_keys_synchronizer._max_bulk_size)): + if i > int(total_mtks / unique_keys_synchronizer._max_bulk_size): + assert(len(bulks[i]['feature1']) == (total_mtks - unique_keys_synchronizer._max_bulk_size)) + else: + assert(len(bulks[i]['feature1']) == unique_keys_synchronizer._max_bulk_size) + + @mock.patch('splitio.engine.impressions.adapters.InMemorySenderAdapter.record_unique_keys') + def test_sync_unique_keys_send_all(self, mtk_mocker): + mtk_mocker.side_effect = self.mocked_record_unique_keys + + total_mtks = 5010 # Use number higher than 5000, which is the default max_bulk_size + unique_keys_tracker = UniqueKeysTracker() + for i in range(0 , total_mtks): + unique_keys_tracker.track('key'+str(i)+'', 'feature1') + sender_adapter = InMemorySenderAdapter(mock.Mock()) + unique_keys_synchronizer = UniqueKeysSynchronizer(sender_adapter, unique_keys_tracker) + unique_keys_synchronizer.send_all() + assert(mtk_mocker.call_count == int(total_mtks / unique_keys_synchronizer._max_bulk_size) + 1) + + def mocked_record_unique_keys(self, cache): + return mock.Mock() + + def test_clear_all_filter(self, mocker): + unique_keys_tracker = UniqueKeysTracker() + total_mtks = 50 + for i in range(0 , total_mtks): + unique_keys_tracker.track('key'+str(i)+'', 'feature1') + + clear_filter_sync = ClearFilterSynchronizer(unique_keys_tracker) + clear_filter_sync.clear_all() + for i in range(0 , total_mtks): + assert(not unique_keys_tracker._filter.contains('feature1key'+str(i))) + + +class UniqueKeysSynchronizerAsyncTests(object): + """Unique keys synchronizer async test cases.""" + + @pytest.mark.asyncio + async def test_sync_unique_keys_chunks(self, mocker): + total_mtks = 5010 # Use number higher than 5000, which is the default max_bulk_size + unique_keys_tracker = UniqueKeysTrackerAsync() + for i in range(0 , total_mtks): + await unique_keys_tracker.track('key'+str(i)+'', 'feature1') + sender_adapter = InMemorySenderAdapterAsync(mocker.Mock()) + unique_keys_synchronizer = UniqueKeysSynchronizerAsync(sender_adapter, unique_keys_tracker) + cache, cache_size = await unique_keys_synchronizer._uniqe_keys_tracker.get_cache_info_and_pop_all() + assert(cache_size > unique_keys_synchronizer._max_bulk_size) + + bulks = unique_keys_synchronizer._split_cache_to_bulks(cache) + assert(len(bulks) == int(total_mtks / unique_keys_synchronizer._max_bulk_size) + 1) + for i in range(0 , int(total_mtks / unique_keys_synchronizer._max_bulk_size)): + if i > int(total_mtks / unique_keys_synchronizer._max_bulk_size): + assert(len(bulks[i]['feature1']) == (total_mtks - unique_keys_synchronizer._max_bulk_size)) + else: + assert(len(bulks[i]['feature1']) == unique_keys_synchronizer._max_bulk_size) + + @pytest.mark.asyncio + async def test_sync_unique_keys_send_all(self): + total_mtks = 5010 # Use number higher than 5000, which is the default max_bulk_size + unique_keys_tracker = UniqueKeysTrackerAsync() + for i in range(0 , total_mtks): + await unique_keys_tracker.track('key'+str(i)+'', 'feature1') + sender_adapter = InMemorySenderAdapterAsync(mock.Mock()) + self.call_count = 0 + async def record_unique_keys(*args): + self.call_count += 1 + + sender_adapter.record_unique_keys = record_unique_keys + unique_keys_synchronizer = UniqueKeysSynchronizerAsync(sender_adapter, unique_keys_tracker) + await unique_keys_synchronizer.send_all() + assert(self.call_count == int(total_mtks / unique_keys_synchronizer._max_bulk_size) + 1) + + @pytest.mark.asyncio + async def test_clear_all_filter(self, mocker): + unique_keys_tracker = UniqueKeysTrackerAsync() + total_mtks = 50 + for i in range(0 , total_mtks): + await unique_keys_tracker.track('key'+str(i)+'', 'feature1') + + clear_filter_sync = ClearFilterSynchronizerAsync(unique_keys_tracker) + await clear_filter_sync.clear_all() + for i in range(0 , total_mtks): + assert(not unique_keys_tracker._filter.contains('feature1key'+str(i))) \ No newline at end of file diff --git a/tests/tasks/test_events_sync.py b/tests/tasks/test_events_sync.py index ec72c883..b2ea500d 100644 --- a/tests/tasks/test_events_sync.py +++ b/tests/tasks/test_events_sync.py @@ -2,12 +2,15 @@ import threading import time +import pytest + from splitio.api.client import HttpResponse from splitio.tasks import events_sync from splitio.storage import EventStorage from splitio.models.events import Event from splitio.api.events import EventsAPI -from splitio.sync.event import EventSynchronizer +from splitio.sync.event import EventSynchronizer, EventSynchronizerAsync +from splitio.optional.loaders import asyncio class EventsSyncTests(object): @@ -26,7 +29,7 @@ def test_normal_operation(self, mocker): storage.pop_many.return_value = events api = mocker.Mock(spec=EventsAPI) - api.flush_events.return_value = HttpResponse(200, '') + api.flush_events.return_value = HttpResponse(200, '', {}) event_synchronizer = EventSynchronizer(api, storage, 5) task = events_sync.EventsSyncTask(event_synchronizer.synchronize_events, 1) task.start() @@ -40,3 +43,47 @@ def test_normal_operation(self, mocker): stop_event.wait(5) assert stop_event.is_set() assert len(api.flush_events.mock_calls) > calls_now + + +class EventsSyncAsyncTests(object): + """Impressions Syncrhonization task async test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + self.events = [ + Event('key1', 'user', 'purchase', 5.3, 123456, None), + Event('key2', 'user', 'purchase', 5.3, 123456, None), + Event('key3', 'user', 'purchase', 5.3, 123456, None), + Event('key4', 'user', 'purchase', 5.3, 123456, None), + Event('key5', 'user', 'purchase', 5.3, 123456, None), + ] + storage = mocker.Mock(spec=EventStorage) + self.called = False + async def pop_many(*args): + self.called = True + return self.events + storage.pop_many = pop_many + + api = mocker.Mock(spec=EventsAPI) + self.flushed_events = None + self.count = 0 + async def flush_events(events): + self.count += 1 + self.flushed_events = events + return HttpResponse(200, '', {}) + api.flush_events = flush_events + + event_synchronizer = EventSynchronizerAsync(api, storage, 5) + task = events_sync.EventsSyncTaskAsync(event_synchronizer.synchronize_events, 1) + task.start() + await asyncio.sleep(2) + + assert task.is_running() + assert self.called + assert self.flushed_events == self.events + + calls_now = self.count + await task.stop() + assert not task.is_running() + assert self.count > calls_now diff --git a/tests/tasks/test_impressions_sync.py b/tests/tasks/test_impressions_sync.py index e81c4e29..78bbf979 100644 --- a/tests/tasks/test_impressions_sync.py +++ b/tests/tasks/test_impressions_sync.py @@ -2,32 +2,33 @@ import threading import time +import pytest + from splitio.api.client import HttpResponse from splitio.tasks import impressions_sync from splitio.storage import ImpressionStorage from splitio.models.impressions import Impression from splitio.api.impressions import ImpressionsAPI -from splitio.sync.impression import ImpressionSynchronizer, ImpressionsCountSynchronizer -from splitio.engine.impressions import Manager as ImpressionsManager -from splitio.engine.impressions import Counter - +from splitio.sync.impression import ImpressionSynchronizer, ImpressionsCountSynchronizer, ImpressionSynchronizerAsync, ImpressionsCountSynchronizerAsync +from splitio.engine.impressions.manager import Counter +from splitio.optional.loaders import asyncio -class ImpressionsSyncTests(object): +class ImpressionsSyncTaskTests(object): """Impressions Syncrhonization task test cases.""" def test_normal_operation(self, mocker): """Test that the task works properly under normal circumstances.""" storage = mocker.Mock(spec=ImpressionStorage) impressions = [ - Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654), - Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654), - Impression('key3', 'split2', 'off', 'l1', 123456, 'b1', 321654), - Impression('key4', 'split2', 'on', 'l1', 123456, 'b1', 321654), - Impression('key5', 'split3', 'off', 'l1', 123456, 'b1', 321654) + Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key3', 'split2', 'off', 'l1', 123456, 'b1', 321654, None, None), + Impression('key4', 'split2', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key5', 'split3', 'off', 'l1', 123456, 'b1', 321654, None, None) ] storage.pop_many.return_value = impressions api = mocker.Mock(spec=ImpressionsAPI) - api.flush_impressions.return_value = HttpResponse(200, '') + api.flush_impressions.return_value = HttpResponse(200, '', {}) impression_synchronizer = ImpressionSynchronizer(api, storage, 5) task = impressions_sync.ImpressionsSyncTask( impression_synchronizer.synchronize_impressions, @@ -46,12 +47,57 @@ def test_normal_operation(self, mocker): assert len(api.flush_impressions.mock_calls) > calls_now -class ImpressionsCountSyncTests(object): +class ImpressionsSyncTaskAsyncTests(object): + """Impressions Syncrhonization task test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + storage = mocker.Mock(spec=ImpressionStorage) + impressions = [ + Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key3', 'split2', 'off', 'l1', 123456, 'b1', 321654, None, None), + Impression('key4', 'split2', 'on', 'l1', 123456, 'b1', 321654, None, None), + Impression('key5', 'split3', 'off', 'l1', 123456, 'b1', 321654, None, None) + ] + self.pop_called = 0 + async def pop_many(*args): + self.pop_called += 1 + return impressions + storage.pop_many = pop_many + + api = mocker.Mock(spec=ImpressionsAPI) + self.flushed = None + self.called = 0 + async def flush_impressions(imps): + self.called += 1 + self.flushed = imps + return HttpResponse(200, '', {}) + api.flush_impressions = flush_impressions + + impression_synchronizer = ImpressionSynchronizerAsync(api, storage, 5) + task = impressions_sync.ImpressionsSyncTaskAsync( + impression_synchronizer.synchronize_impressions, + 1 + ) + task.start() + await asyncio.sleep(2) + assert task.is_running() + assert self.pop_called == 1 + assert self.flushed == impressions + + calls_now = self.called + await task.stop() + assert self.called > calls_now + + +class ImpressionsCountSyncTaskTests(object): """Impressions Syncrhonization task test cases.""" def test_normal_operation(self, mocker): """Test that the task works properly under normal circumstances.""" - manager = mocker.Mock(spec=ImpressionsManager) + counter = mocker.Mock(spec=Counter) counters = [ Counter.CountPerFeature('f1', 123, 2), @@ -60,18 +106,18 @@ def test_normal_operation(self, mocker): Counter.CountPerFeature('f2', 456, 222) ] - manager.get_counts.return_value = counters + counter.pop_all.return_value = counters api = mocker.Mock(spec=ImpressionsAPI) - api.flush_counters.return_value = HttpResponse(200, '') + api.flush_counters.return_value = HttpResponse(200, '', {}) impressions_sync.ImpressionsCountSyncTask._PERIOD = 1 - impression_synchronizer = ImpressionsCountSynchronizer(api, manager) + impression_synchronizer = ImpressionsCountSynchronizer(api, counter) task = impressions_sync.ImpressionsCountSyncTask( impression_synchronizer.synchronize_counters ) task.start() time.sleep(2) assert task.is_running() - assert manager.get_counts.mock_calls[0] == mocker.call() + assert counter.pop_all.mock_calls[0] == mocker.call() assert api.flush_counters.mock_calls[0] == mocker.call(counters) stop_event = threading.Event() calls_now = len(api.flush_counters.mock_calls) @@ -79,3 +125,48 @@ def test_normal_operation(self, mocker): stop_event.wait(5) assert stop_event.is_set() assert len(api.flush_counters.mock_calls) > calls_now + + +class ImpressionsCountSyncTaskAsyncTests(object): + """Impressions Syncrhonization task test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + counter = mocker.Mock(spec=Counter) + counters = [ + Counter.CountPerFeature('f1', 123, 2), + Counter.CountPerFeature('f2', 123, 123), + Counter.CountPerFeature('f1', 456, 111), + Counter.CountPerFeature('f2', 456, 222) + ] + self._pop_called = 0 + def pop_all(): + self._pop_called += 1 + return counters + counter.pop_all = pop_all + + api = mocker.Mock(spec=ImpressionsAPI) + self.flushed = None + self.called = 0 + async def flush_counters(imps): + self.called += 1 + self.flushed = imps + return HttpResponse(200, '', {}) + api.flush_counters = flush_counters + + impressions_sync.ImpressionsCountSyncTaskAsync._PERIOD = 1 + impression_synchronizer = ImpressionsCountSynchronizerAsync(api, counter) + task = impressions_sync.ImpressionsCountSyncTaskAsync( + impression_synchronizer.synchronize_counters + ) + task.start() + await asyncio.sleep(2) + assert task.is_running() + + assert self._pop_called == 1 + assert self.flushed == counters + + calls_now = self.called + await task.stop() + assert self.called > calls_now diff --git a/tests/tasks/test_segment_sync.py b/tests/tasks/test_segment_sync.py index 91482a40..cc701e52 100644 --- a/tests/tasks/test_segment_sync.py +++ b/tests/tasks/test_segment_sync.py @@ -2,15 +2,17 @@ import threading import time +import pytest + from splitio.api.commons import FetchOptions from splitio.tasks import segment_sync -from splitio.storage import SegmentStorage, SplitStorage +from splitio.storage import SegmentStorage, SplitStorage, RuleBasedSegmentsStorage from splitio.models.splits import Split from splitio.models.segments import Segment from splitio.models.grammar.condition import Condition from splitio.models.grammar.matchers import UserDefinedSegmentMatcher -from splitio.sync.segment import SegmentSynchronizer - +from splitio.sync.segment import SegmentSynchronizer, SegmentSynchronizerAsync +from splitio.optional.loaders import asyncio class SegmentSynchronizationTests(object): """Split synchronization task test cases.""" @@ -19,6 +21,8 @@ def test_normal_operation(self, mocker): """Test the normal operation flow.""" split_storage = mocker.Mock(spec=SplitStorage) split_storage.get_segment_names.return_value = ['segmentA', 'segmentB', 'segmentC'] + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] # Setup a mocked segment storage whose changenumber returns -1 on first fetch and # 123 afterwards. @@ -60,10 +64,10 @@ def fetch_segment_mock(segment_name, change_number, fetch_options): fetch_segment_mock._count_c = 0 api = mocker.Mock() - fetch_options = FetchOptions(True) + fetch_options = FetchOptions(True, None, None, None, None) api.fetch_segment.side_effect = fetch_segment_mock - segments_synchronizer = SegmentSynchronizer(api, split_storage, storage) + segments_synchronizer = SegmentSynchronizer(api, split_storage, storage, rbs_storage) task = segment_sync.SegmentSynchronizationTask(segments_synchronizer.synchronize_segments, 0.5) task.start() @@ -95,4 +99,276 @@ def fetch_segment_mock(segment_name, change_number, fetch_options): def test_that_errors_dont_stop_task(self, mocker): """Test that if fetching segments fails at some_point, the task will continue running.""" - # TODO! + split_storage = mocker.Mock(spec=SplitStorage) + split_storage.get_segment_names.return_value = ['segmentA', 'segmentB', 'segmentC'] + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + rbs_storage.get_segment_names.return_value = [] + + # Setup a mocked segment storage whose changenumber returns -1 on first fetch and + # 123 afterwards. + storage = mocker.Mock(spec=SegmentStorage) + + def change_number_mock(segment_name): + if segment_name == 'segmentA' and change_number_mock._count_a == 0: + change_number_mock._count_a = 1 + return -1 + if segment_name == 'segmentB' and change_number_mock._count_b == 0: + change_number_mock._count_b = 1 + return -1 + if segment_name == 'segmentC' and change_number_mock._count_c == 0: + change_number_mock._count_c = 1 + return -1 + return 123 + change_number_mock._count_a = 0 + change_number_mock._count_b = 0 + change_number_mock._count_c = 0 + storage.get_change_number.side_effect = change_number_mock + + # Setup a mocked segment api to return segments mentioned before. + def fetch_segment_mock(segment_name, change_number, fetch_options): + if segment_name == 'segmentA' and fetch_segment_mock._count_a == 0: + fetch_segment_mock._count_a = 1 + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentB' and fetch_segment_mock._count_b == 0: + fetch_segment_mock._count_b = 1 + raise Exception("some exception") + if segment_name == 'segmentC' and fetch_segment_mock._count_c == 0: + fetch_segment_mock._count_c = 1 + return {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], + 'since': -1, 'till': 123} + return {'added': [], 'removed': [], 'since': 123, 'till': 123} + fetch_segment_mock._count_a = 0 + fetch_segment_mock._count_b = 0 + fetch_segment_mock._count_c = 0 + + api = mocker.Mock() + fetch_options = FetchOptions(True, None, None, None, None) + api.fetch_segment.side_effect = fetch_segment_mock + + segments_synchronizer = SegmentSynchronizer(api, split_storage, storage, rbs_storage) + task = segment_sync.SegmentSynchronizationTask(segments_synchronizer.synchronize_segments, + 0.5) + task.start() + time.sleep(0.7) + + assert task.is_running() + + stop_event = threading.Event() + task.stop(stop_event) + stop_event.wait() + assert not task.is_running() + + api_calls = [call for call in api.fetch_segment.mock_calls] + assert mocker.call('segmentA', -1, fetch_options) in api_calls + assert mocker.call('segmentB', -1, fetch_options) in api_calls + assert mocker.call('segmentC', -1, fetch_options) in api_calls + assert mocker.call('segmentA', 123, fetch_options) in api_calls + assert mocker.call('segmentC', 123, fetch_options) in api_calls + + segment_put_calls = storage.put.mock_calls + segments_to_validate = set(['segmentA', 'segmentB', 'segmentC']) + for call in segment_put_calls: + _, positional_args, _ = call + segment = positional_args[0] + assert isinstance(segment, Segment) + assert segment.name in segments_to_validate + segments_to_validate.remove(segment.name) + + +class SegmentSynchronizationAsyncTests(object): + """Split synchronization async task test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test the normal operation flow.""" + split_storage = mocker.Mock(spec=SplitStorage) + async def get_segment_names(): + return ['segmentA', 'segmentB', 'segmentC'] + split_storage.get_segment_names = get_segment_names + + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + async def get_segment_names_rbs(): + return [] + rbs_storage.get_segment_names = get_segment_names_rbs + + # Setup a mocked segment storage whose changenumber returns -1 on first fetch and + # 123 afterwards. + storage = mocker.Mock(spec=SegmentStorage) + + async def change_number_mock(segment_name): + if segment_name == 'segmentA' and change_number_mock._count_a == 0: + change_number_mock._count_a = 1 + return -1 + if segment_name == 'segmentB' and change_number_mock._count_b == 0: + change_number_mock._count_b = 1 + return -1 + if segment_name == 'segmentC' and change_number_mock._count_c == 0: + change_number_mock._count_c = 1 + return -1 + return 123 + change_number_mock._count_a = 0 + change_number_mock._count_b = 0 + change_number_mock._count_c = 0 + storage.get_change_number = change_number_mock + + self.segments = [] + async def put(segment): + self.segments.append(segment) + storage.put = put + + async def update(*arg): + pass + storage.update = update + + # Setup a mocked segment api to return segments mentioned before. + self.segment_name = [] + self.change_number = [] + self.fetch_options = [] + async def fetch_segment_mock(segment_name, change_number, fetch_options): + self.segment_name.append(segment_name) + self.change_number.append(change_number) + self.fetch_options.append(fetch_options) + if segment_name == 'segmentA' and fetch_segment_mock._count_a == 0: + fetch_segment_mock._count_a = 1 + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentB' and fetch_segment_mock._count_b == 0: + fetch_segment_mock._count_b = 1 + return {'name': 'segmentB', 'added': ['key4', 'key5', 'key6'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentC' and fetch_segment_mock._count_c == 0: + fetch_segment_mock._count_c = 1 + return {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], + 'since': -1, 'till': 123} + return {'added': [], 'removed': [], 'since': 123, 'till': 123} + fetch_segment_mock._count_a = 0 + fetch_segment_mock._count_b = 0 + fetch_segment_mock._count_c = 0 + + api = mocker.Mock() + fetch_options = FetchOptions(True, None, None, None, None) + api.fetch_segment = fetch_segment_mock + + segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + task = segment_sync.SegmentSynchronizationTaskAsync(segments_synchronizer.synchronize_segments, + 0.5) + task.start() + await asyncio.sleep(0.7) + assert task.is_running() + + await task.stop() + assert not task.is_running() + + api_calls = [] + for i in range(6): + api_calls.append((self.segment_name[i], self.change_number[i], self.fetch_options[i])) + + assert ('segmentA', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentA', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentB', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentB', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentC', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentC', 123, FetchOptions(True, None, None, None, None)) in api_calls + + segments_to_validate = set(['segmentA', 'segmentB', 'segmentC']) + for segment in self.segments: + assert isinstance(segment, Segment) + assert segment.name in segments_to_validate + segments_to_validate.remove(segment.name) + + @pytest.mark.asyncio + async def test_that_errors_dont_stop_task(self, mocker): + """Test that if fetching segments fails at some_point, the task will continue running.""" + split_storage = mocker.Mock(spec=SplitStorage) + async def get_segment_names(): + return ['segmentA', 'segmentB', 'segmentC'] + split_storage.get_segment_names = get_segment_names + + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + async def get_segment_names_rbs(): + return [] + rbs_storage.get_segment_names = get_segment_names_rbs + + # Setup a mocked segment storage whose changenumber returns -1 on first fetch and + # 123 afterwards. + storage = mocker.Mock(spec=SegmentStorage) + + async def change_number_mock(segment_name): + if segment_name == 'segmentA' and change_number_mock._count_a == 0: + change_number_mock._count_a = 1 + return -1 + if segment_name == 'segmentB' and change_number_mock._count_b == 0: + change_number_mock._count_b = 1 + return -1 + if segment_name == 'segmentC' and change_number_mock._count_c == 0: + change_number_mock._count_c = 1 + return -1 + return 123 + change_number_mock._count_a = 0 + change_number_mock._count_b = 0 + change_number_mock._count_c = 0 + storage.get_change_number = change_number_mock + + self.segments = [] + async def put(segment): + self.segments.append(segment) + storage.put = put + + async def update(*arg): + pass + storage.update = update + + # Setup a mocked segment api to return segments mentioned before. + self.segment_name = [] + self.change_number = [] + self.fetch_options = [] + async def fetch_segment_mock(segment_name, change_number, fetch_options): + self.segment_name.append(segment_name) + self.change_number.append(change_number) + self.fetch_options.append(fetch_options) + if segment_name == 'segmentA' and fetch_segment_mock._count_a == 0: + fetch_segment_mock._count_a = 1 + return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], + 'since': -1, 'till': 123} + if segment_name == 'segmentB' and fetch_segment_mock._count_b == 0: + fetch_segment_mock._count_b = 1 + raise Exception("some exception") + if segment_name == 'segmentC' and fetch_segment_mock._count_c == 0: + fetch_segment_mock._count_c = 1 + return {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], + 'since': -1, 'till': 123} + return {'added': [], 'removed': [], 'since': 123, 'till': 123} + fetch_segment_mock._count_a = 0 + fetch_segment_mock._count_b = 0 + fetch_segment_mock._count_c = 0 + + api = mocker.Mock() + fetch_options = FetchOptions(True, None, None, None, None) + api.fetch_segment = fetch_segment_mock + + segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage, rbs_storage) + task = segment_sync.SegmentSynchronizationTaskAsync(segments_synchronizer.synchronize_segments, + 0.5) + task.start() + await asyncio.sleep(0.7) + assert task.is_running() + + await task.stop() + assert not task.is_running() + + api_calls = [] + for i in range(5): + api_calls.append((self.segment_name[i], self.change_number[i], self.fetch_options[i])) + + assert ('segmentA', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentA', 123, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentB', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentC', -1, FetchOptions(True, None, None, None, None)) in api_calls + assert ('segmentC', 123, FetchOptions(True, None, None, None, None)) in api_calls + + segments_to_validate = set(['segmentA', 'segmentB', 'segmentC']) + for segment in self.segments: + assert isinstance(segment, Segment) + assert segment.name in segments_to_validate + segments_to_validate.remove(segment.name) diff --git a/tests/tasks/test_split_sync.py b/tests/tasks/test_split_sync.py index adc90724..c9a0c692 100644 --- a/tests/tasks/test_split_sync.py +++ b/tests/tasks/test_split_sync.py @@ -1,13 +1,50 @@ """Split syncrhonization task test module.""" - import threading import time +import pytest + from splitio.api import APIException from splitio.api.commons import FetchOptions from splitio.tasks import split_sync -from splitio.storage import SplitStorage +from splitio.storage import SplitStorage, RuleBasedSegmentsStorage from splitio.models.splits import Split -from splitio.sync.split import SplitSynchronizer +from splitio.sync.split import SplitSynchronizer, SplitSynchronizerAsync +from splitio.optional.loaders import asyncio + +splits = [{ + 'changeNumber': 123, + 'trafficTypeName': 'user', + 'name': 'some_name', + 'trafficAllocation': 100, + 'trafficAllocationSeed': 123456, + 'seed': 321654, + 'status': 'ACTIVE', + 'killed': False, + 'defaultTreatment': 'off', + 'algo': 2, + 'conditions': [ + { + 'partitions': [ + {'treatment': 'on', 'size': 50}, + {'treatment': 'off', 'size': 50} + ], + 'contitionType': 'WHITELIST', + 'label': 'some_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'WHITELIST', + 'whitelistMatcherData': { + 'whitelist': ['k1', 'k2', 'k3'] + }, + 'negate': False, + } + ], + 'combiner': 'AND' + } + } + ] +}] class SplitSynchronizationTests(object): @@ -16,6 +53,7 @@ class SplitSynchronizationTests(object): def test_normal_operation(self, mocker): """Test the normal operation flow.""" storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) def change_number_mock(): change_number_mock._calls += 1 @@ -25,62 +63,41 @@ def change_number_mock(): change_number_mock._calls = 0 storage.get_change_number.side_effect = change_number_mock + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + self.clear = False + def clear(): + self.clear = True + storage.clear = clear + rbs_storage.clear = clear + api = mocker.Mock() - splits = [{ - 'changeNumber': 123, - 'trafficTypeName': 'user', - 'name': 'some_name', - 'trafficAllocation': 100, - 'trafficAllocationSeed': 123456, - 'seed': 321654, - 'status': 'ACTIVE', - 'killed': False, - 'defaultTreatment': 'off', - 'algo': 2, - 'conditions': [ - { - 'partitions': [ - {'treatment': 'on', 'size': 50}, - {'treatment': 'off', 'size': 50} - ], - 'contitionType': 'WHITELIST', - 'label': 'some_label', - 'matcherGroup': { - 'matchers': [ - { - 'matcherType': 'WHITELIST', - 'whitelistMatcherData': { - 'whitelist': ['k1', 'k2', 'k3'] - }, - 'negate': False, - } - ], - 'combiner': 'AND' - } - } - ] - }] def get_changes(*args, **kwargs): get_changes.called += 1 if get_changes.called == 1: - return { - 'splits': splits, - 'since': -1, - 'till': 123 + return {'ff': { + 'd': splits, + 's': -1, + 't': 123}, 'rbs': {'d': [], 't': -1, 's': -1} } else: - return { - 'splits': [], - 'since': 123, - 'till': 123 - } + return {'ff': {'d': [],'s': 123, 't': 123}, + 'rbs': {'d': [], 't': -1, 's': -1}} get_changes.called = 0 fetch_options = FetchOptions(True) api.fetch_splits.side_effect = get_changes - split_synchronizer = SplitSynchronizer(api, storage) + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) task = split_sync.SplitSynchronizationTask(split_synchronizer.synchronize_splits, 0.5) task.start() time.sleep(0.7) @@ -89,30 +106,35 @@ def get_changes(*args, **kwargs): task.stop(stop_event) stop_event.wait() assert not task.is_running() - assert mocker.call(-1, fetch_options) in api.fetch_splits.mock_calls - assert mocker.call(123, fetch_options) in api.fetch_splits.mock_calls + assert api.fetch_splits.mock_calls[0][1][0] == -1 + assert api.fetch_splits.mock_calls[0][1][2].cache_control_headers == True + assert api.fetch_splits.mock_calls[1][1][0] == 123 + assert api.fetch_splits.mock_calls[1][1][2].cache_control_headers == True - inserted_split = storage.put.mock_calls[0][1][0] + inserted_split = storage.update.mock_calls[0][1][0][0] assert isinstance(inserted_split, Split) assert inserted_split.name == 'some_name' def test_that_errors_dont_stop_task(self, mocker): """Test that if fetching splits fails at some_point, the task will continue running.""" storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) api = mocker.Mock() def run(x): run._calls += 1 if run._calls == 1: - return {'splits': [], 'since': -1, 'till': -1} + return {'ff': {'d': [],'s': -1, 't': -1}, + 'rbs': {'d': [], 't': -1, 's': -1}} if run._calls == 2: - return {'splits': [], 'since': -1, 'till': -1} + return {'ff': {'d': [],'s': -1, 't': -1}, + 'rbs': {'d': [], 't': -1, 's': -1}} raise APIException("something broke") run._calls = 0 api.fetch_splits.side_effect = run storage.get_change_number.return_value = -1 - split_synchronizer = SplitSynchronizer(api, storage) + split_synchronizer = SplitSynchronizer(api, storage, rbs_storage) task = split_sync.SplitSynchronizationTask(split_synchronizer.synchronize_splits, 0.5) task.start() time.sleep(0.1) @@ -120,3 +142,114 @@ def run(x): time.sleep(1) assert task.is_running() task.stop() + + +class SplitSynchronizationAsyncTests(object): + """Split synchronization task async test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test the normal operation flow.""" + storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + + async def change_number_mock(): + change_number_mock._calls += 1 + if change_number_mock._calls == 1: + return -1 + return 123 + change_number_mock._calls = 0 + storage.get_change_number = change_number_mock + async def rb_change_number_mock(): + return -1 + rbs_storage.get_change_number = rb_change_number_mock + + class flag_set_filter(): + def should_filter(): + return False + + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + storage.flag_set_filter.sorted_flag_sets = [] + + async def set_change_number(*_): + pass + change_number_mock._calls = 0 + storage.set_change_number = set_change_number + + self.clear = False + async def clear(): + self.clear = True + storage.clear = clear + rbs_storage.clear = clear + + api = mocker.Mock() + self.change_number = [] + self.fetch_options = [] + async def get_changes(change_number, rb_change_number, fetch_options): + self.change_number.append(change_number) + self.fetch_options.append(fetch_options) + get_changes.called += 1 + if get_changes.called == 1: + return {'ff': {'d': splits,'s': -1, 't': 123}, + 'rbs': {'d': [], 't': -1, 's': -1}} + else: + return {'ff': {'d': [],'s': 123, 't': 123}, + 'rbs': {'d': [], 't': -1, 's': -1}} + api.fetch_splits = get_changes + get_changes.called = 0 + self.inserted_split = None + async def update(split, deleted, change_number): + if len(split) > 0: + self.inserted_split = split + storage.update = update + async def rbs_update(split, deleted, change_number): + pass + rbs_storage.update = rbs_update + + fetch_options = FetchOptions(True) + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + task = split_sync.SplitSynchronizationTaskAsync(split_synchronizer.synchronize_splits, 0.5) + task.start() + await asyncio.sleep(2) + assert task.is_running() + await task.stop() + assert not task.is_running() + assert (self.change_number[0], self.fetch_options[0].cache_control_headers) == (-1, fetch_options.cache_control_headers) + assert (self.change_number[1], self.fetch_options[1].cache_control_headers, self.fetch_options[1].change_number) == (123, fetch_options.cache_control_headers, fetch_options.change_number) + assert isinstance(self.inserted_split[0], Split) + assert self.inserted_split[0].name == 'some_name' + + @pytest.mark.asyncio + async def test_that_errors_dont_stop_task(self, mocker): + """Test that if fetching splits fails at some_point, the task will continue running.""" + storage = mocker.Mock(spec=SplitStorage) + rbs_storage = mocker.Mock(spec=RuleBasedSegmentsStorage) + api = mocker.Mock() + + async def run(x): + run._calls += 1 + if run._calls == 1: + return {'ff': {'d': [],'s': -1, 't': -1}, + 'rbs': {'d': [], 't': -1, 's': -1}} + if run._calls == 2: + return {'ff': {'d': [],'s': -1, 't': -1}, + 'rbs': {'d': [], 't': -1, 's': -1}} + raise APIException("something broke") + run._calls = 0 + api.fetch_splits = run + + async def get_change_number(): + return -1 + storage.get_change_number = get_change_number + + split_synchronizer = SplitSynchronizerAsync(api, storage, rbs_storage) + task = split_sync.SplitSynchronizationTaskAsync(split_synchronizer.synchronize_splits, 0.5) + task.start() + await asyncio.sleep(0.1) + assert task.is_running() + await asyncio.sleep(1) + assert task.is_running() + await task.stop() diff --git a/tests/tasks/test_telemetry_sync.py b/tests/tasks/test_telemetry_sync.py new file mode 100644 index 00000000..21a887d0 --- /dev/null +++ b/tests/tasks/test_telemetry_sync.py @@ -0,0 +1,67 @@ +"""Impressions synchronization task test module.""" +import pytest +import threading +import time +from splitio.api.client import HttpResponse +from splitio.tasks.telemetry_sync import TelemetrySyncTask, TelemetrySyncTaskAsync +from splitio.api.telemetry import TelemetryAPI, TelemetryAPIAsync +from splitio.sync.telemetry import TelemetrySynchronizer, TelemetrySynchronizerAsync, InMemoryTelemetrySubmitter, InMemoryTelemetrySubmitterAsync +from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync +from splitio.engine.telemetry import TelemetryStorageConsumer, TelemetryStorageConsumerAsync +from splitio.optional.loaders import asyncio + + +class TelemetrySyncTaskTests(object): + """Unique Keys Syncrhonization task test cases.""" + + def test_record_stats(self, mocker): + """Test that the task works properly under normal circumstances.""" + api = mocker.Mock(spec=TelemetryAPI) + api.record_stats.return_value = HttpResponse(200, '', {}) + telemetry_storage = InMemoryTelemetryStorage() + telemetry_consumer = TelemetryStorageConsumer(telemetry_storage) + telemetry_submitter = InMemoryTelemetrySubmitter(telemetry_consumer, mocker.Mock(), mocker.Mock(), api) + def _build_stats(): + return {} + telemetry_submitter._build_stats = _build_stats + + telemetry_synchronizer = TelemetrySynchronizer(telemetry_submitter) + task = TelemetrySyncTask(telemetry_synchronizer.synchronize_stats, 1) + task.start() + time.sleep(2) + assert task.is_running() + assert len(api.record_stats.mock_calls) >= 1 + stop_event = threading.Event() + task.stop(stop_event) + stop_event.wait(5) + assert stop_event.is_set() + + +class TelemetrySyncTaskAsyncTests(object): + """Unique Keys Syncrhonization task test cases.""" + + @pytest.mark.asyncio + async def test_record_stats(self, mocker): + """Test that the task works properly under normal circumstances.""" + api = mocker.Mock(spec=TelemetryAPIAsync) + self.called = False + async def record_stats(stats): + self.called = True + return HttpResponse(200, '', {}) + api.record_stats = record_stats + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_consumer = TelemetryStorageConsumerAsync(telemetry_storage) + telemetry_submitter = InMemoryTelemetrySubmitterAsync(telemetry_consumer, mocker.Mock(), mocker.Mock(), api) + async def _build_stats(): + return {} + telemetry_submitter._build_stats = _build_stats + + telemetry_synchronizer = TelemetrySynchronizerAsync(telemetry_submitter) + task = TelemetrySyncTaskAsync(telemetry_synchronizer.synchronize_stats, 1) + task.start() + await asyncio.sleep(2) + assert task.is_running() + assert self.called + await task.stop() + assert not task.is_running() diff --git a/tests/tasks/test_unique_keys_sync.py b/tests/tasks/test_unique_keys_sync.py new file mode 100644 index 00000000..d04f9271 --- /dev/null +++ b/tests/tasks/test_unique_keys_sync.py @@ -0,0 +1,102 @@ +"""Impressions synchronization task test module.""" +import asyncio +import threading +import time +import pytest + +from splitio.api.client import HttpResponse +from splitio.tasks.unique_keys_sync import UniqueKeysSyncTask, ClearFilterSyncTask,\ + ClearFilterSyncTaskAsync, UniqueKeysSyncTaskAsync +from splitio.api.telemetry import TelemetryAPI +from splitio.sync.unique_keys import UniqueKeysSynchronizer, ClearFilterSynchronizer,\ + UniqueKeysSynchronizerAsync, ClearFilterSynchronizerAsync +from splitio.engine.impressions.unique_keys_tracker import UniqueKeysTracker, UniqueKeysTrackerAsync + + +class UniqueKeysSyncTests(object): + """Unique Keys Syncrhonization task test cases.""" + + def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + api = mocker.Mock(spec=TelemetryAPI) + api.record_unique_keys.return_value = HttpResponse(200, '', {}) + + unique_keys_tracker = UniqueKeysTracker() + unique_keys_tracker.track("key1", "split1") + unique_keys_tracker.track("key2", "split1") + + unique_keys_sync = UniqueKeysSynchronizer(mocker.Mock(), unique_keys_tracker) + task = UniqueKeysSyncTask(unique_keys_sync.send_all, 1) + task.start() + time.sleep(2) + assert task.is_running() + assert api.record_unique_keys.mock_calls == mocker.call() + stop_event = threading.Event() + task.stop(stop_event) + stop_event.wait(5) + assert stop_event.is_set() + +class ClearFilterSyncTests(object): + """Clear Filter Syncrhonization task test cases.""" + + def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + + unique_keys_tracker = UniqueKeysTracker() + unique_keys_tracker.track("key1", "split1") + unique_keys_tracker.track("key2", "split1") + + clear_filter_sync = ClearFilterSynchronizer(unique_keys_tracker) + task = ClearFilterSyncTask(clear_filter_sync.clear_all, 1) + task.start() + time.sleep(2) + assert task.is_running() + assert not unique_keys_tracker._filter.contains("split1key1") + assert not unique_keys_tracker._filter.contains("split1key2") + stop_event = threading.Event() + task.stop(stop_event) + stop_event.wait(5) + assert stop_event.is_set() + +class UniqueKeysSyncAsyncTests(object): + """Unique Keys Syncrhonization task test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + api = mocker.Mock(spec=TelemetryAPI) + api.record_unique_keys.return_value = HttpResponse(200, '', {}) + + unique_keys_tracker = UniqueKeysTrackerAsync() + await unique_keys_tracker.track("key1", "split1") + await unique_keys_tracker.track("key2", "split1") + + unique_keys_sync = UniqueKeysSynchronizerAsync(mocker.Mock(), unique_keys_tracker) + task = UniqueKeysSyncTaskAsync(unique_keys_sync.send_all, 1) + task.start() + await asyncio.sleep(2) + assert task.is_running() + assert api.record_unique_keys.mock_calls == mocker.call() + await task.stop() + assert not task.is_running() + +class ClearFilterSyncTests(object): + """Clear Filter Syncrhonization task test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test that the task works properly under normal circumstances.""" + + unique_keys_tracker = UniqueKeysTrackerAsync() + await unique_keys_tracker.track("key1", "split1") + await unique_keys_tracker.track("key2", "split1") + + clear_filter_sync = ClearFilterSynchronizerAsync(unique_keys_tracker) + task = ClearFilterSyncTaskAsync(clear_filter_sync.clear_all, 1) + task.start() + await asyncio.sleep(2) + assert task.is_running() + assert not unique_keys_tracker._filter.contains("split1key1") + assert not unique_keys_tracker._filter.contains("split1key2") + await task.stop() + assert not task.is_running() diff --git a/tests/tasks/util/test_asynctask.py b/tests/tasks/util/test_asynctask.py index a22b4b45..b587b9c5 100644 --- a/tests/tasks/util/test_asynctask.py +++ b/tests/tasks/util/test_asynctask.py @@ -2,8 +2,10 @@ import time import threading -from splitio.tasks.util import asynctask +import pytest +from splitio.tasks.util import asynctask +from splitio.optional.loaders import asyncio class AsyncTaskTests(object): """AsyncTask test cases.""" @@ -90,7 +92,7 @@ def raise_exception(): task.stop(on_stop_event) on_stop_event.wait(1) - assert on_stop_event.isSet() + assert on_stop_event.is_set() assert on_init.mock_calls == [mocker.call()] assert on_stop.mock_calls == [mocker.call()] assert 9 <= len(main_func.mock_calls) <= 10 @@ -111,8 +113,145 @@ def test_force_run(self, mocker): task.stop(on_stop_event) on_stop_event.wait(1) - assert on_stop_event.isSet() + assert on_stop_event.is_set() assert on_init.mock_calls == [mocker.call()] assert on_stop.mock_calls == [mocker.call()] assert len(main_func.mock_calls) == 2 assert not task.running() + + +class AsyncTaskAsyncTests(object): + """AsyncTask test cases.""" + + @pytest.mark.asyncio + async def test_default_task_flow(self, mocker): + """Test the default execution flow of an asynctask.""" + self.main_called = 0 + async def main_func(): + self.main_called += 1 + + self.init_called = 0 + async def on_init(): + self.init_called += 1 + + self.stop_called = 0 + async def on_stop(): + self.stop_called += 1 + + task = asynctask.AsyncTaskAsync(main_func, 0.5, on_init, on_stop) + task.start() + await asyncio.sleep(1) + assert task.running() + await task.stop(True) + + assert 0 < self.main_called <= 2 + assert self.init_called == 1 + assert self.stop_called == 1 + assert not task.running() + + @pytest.mark.asyncio + async def test_main_exception_skips_iteration(self, mocker): + """Test that an exception in the main func only skips current iteration.""" + self.main_called = 0 + async def raise_exception(): + self.main_called += 1 + raise Exception('something') + main_func = raise_exception + + self.init_called = 0 + async def on_init(): + self.init_called += 1 + + self.stop_called = 0 + async def on_stop(): + self.stop_called += 1 + + task = asynctask.AsyncTaskAsync(main_func, 0.1, on_init, on_stop) + task.start() + await asyncio.sleep(1) + assert task.running() + await task.stop(True) + + assert 9 <= self.main_called <= 10 + assert self.init_called == 1 + assert self.stop_called == 1 + assert not task.running() + + @pytest.mark.asyncio + async def test_on_init_failure_aborts_task(self, mocker): + """Test that if the on_init callback fails, the task never runs.""" + self.main_called = 0 + async def main_func(): + self.main_called += 1 + + self.init_called = 0 + async def on_init(): + self.init_called += 1 + raise Exception('something') + + self.stop_called = 0 + async def on_stop(): + self.stop_called += 1 + + task = asynctask.AsyncTaskAsync(main_func, 0.1, on_init, on_stop) + task.start() + await asyncio.sleep(0.5) + assert not task.running() # Since on_init fails, task never starts + await task.stop(True) + + assert self.init_called == 1 + assert self.stop_called == 1 + assert self.main_called == 0 + assert not task.running() + + @pytest.mark.asyncio + async def test_on_stop_failure_ends_gacefully(self, mocker): + """Test that if the on_init callback fails, the task never runs.""" + self.main_called = 0 + async def main_func(): + self.main_called += 1 + + self.init_called = 0 + async def on_init(): + self.init_called += 1 + + self.stop_called = 0 + async def on_stop(): + self.stop_called += 1 + raise Exception('something') + + task = asynctask.AsyncTaskAsync(main_func, 0.1, on_init, on_stop) + task.start() + await asyncio.sleep(1) + await task.stop(True) + assert 9 <= self.main_called <= 10 + assert self.init_called == 1 + assert self.stop_called == 1 + + @pytest.mark.asyncio + async def test_force_run(self, mocker): + """Test that if the on_init callback fails, the task never runs.""" + self.main_called = 0 + async def main_func(): + self.main_called += 1 + + self.init_called = 0 + async def on_init(): + self.init_called += 1 + + self.stop_called = 0 + async def on_stop(): + self.stop_called += 1 + + task = asynctask.AsyncTaskAsync(main_func, 5, on_init, on_stop) + task.start() + await asyncio.sleep(1) + assert task.running() + task.force_execution() + task.force_execution() + await task.stop(True) + + assert self.main_called == 2 + assert self.init_called == 1 + assert self.stop_called == 1 + assert not task.running() diff --git a/tests/tasks/util/test_workerpool.py b/tests/tasks/util/test_workerpool.py index ab126a17..2f7a8e71 100644 --- a/tests/tasks/util/test_workerpool.py +++ b/tests/tasks/util/test_workerpool.py @@ -2,8 +2,10 @@ # pylint: disable=no-self-use,too-few-public-methods,missing-docstring import time import threading -from splitio.tasks.util import workerpool +import pytest +from splitio.tasks.util import workerpool +from splitio.optional.loaders import asyncio class WorkerPoolTests(object): """Worker pool test cases.""" @@ -71,3 +73,79 @@ def do_work(self, work): wpool.wait_for_completion() assert len(worker.worked) == 100 + + +class WorkerPoolAsyncTests(object): + """Worker pool async test cases.""" + + @pytest.mark.asyncio + async def test_normal_operation(self, mocker): + """Test normal opeation works properly.""" + self.calls = 0 + calls = [] + async def worker_func(num): + self.calls += 1 + calls.append(num) + + wpool = workerpool.WorkerPoolAsync(10, worker_func) + wpool.start() + jobs = [] + for num in range(0, 11): + jobs.append(str(num)) + + task = await wpool.submit_work(jobs) + assert await task.await_completion() + await wpool.stop() + for num in range(0, 11): + assert str(num) in calls + + @pytest.mark.asyncio + async def test_fail_in_msg_doesnt_break(self): + """Test that if a message cannot be parsed it is ignored and others are processed.""" + class Worker(object): #pylint: disable= + def __init__(self): + self.worked = set() + + async def do_work(self, work): + if work == '55': + raise Exception('something') + self.worked.add(work) + + worker = Worker() + wpool = workerpool.WorkerPoolAsync(50, worker.do_work) + wpool.start() + jobs = [] + for num in range(0, 100): + jobs.append(str(num)) + task = await wpool.submit_work(jobs) + + assert not await task.await_completion() + await wpool.stop() + + for num in range(0, 100): + if num != 55: + assert str(num) in worker.worked + else: + assert str(num) not in worker.worked + + @pytest.mark.asyncio + async def test_msg_acked_after_processed(self): + """Test that events are only set after all the work in the pipeline is done.""" + class Worker(object): + def __init__(self): + self.worked = set() + + async def do_work(self, work): + self.worked.add(work) + await asyncio.sleep(0.02) # will wait 2 seconds in total for 100 elements + + worker = Worker() + wpool = workerpool.WorkerPoolAsync(50, worker.do_work) + wpool.start() + jobs = [] + for num in range(0, 100): + jobs.append(str(num)) + task = await wpool.submit_work(jobs) + assert await task.await_completion() + await wpool.stop() + assert len(worker.worked) == 100 diff --git a/tests/util/test_storage_helper.py b/tests/util/test_storage_helper.py new file mode 100644 index 00000000..60e83e8c --- /dev/null +++ b/tests/util/test_storage_helper.py @@ -0,0 +1,339 @@ +"""Storage Helper tests.""" +import pytest +import queue +import asyncio + +from splitio.util.storage_helper import update_feature_flag_storage, get_valid_flag_sets, combine_valid_flag_sets, \ + update_rule_based_segment_storage, update_rule_based_segment_storage_async, update_feature_flag_storage_async, \ + get_standard_segment_names_in_rbs_storage_async, get_standard_segment_names_in_rbs_storage +from splitio.storage.inmemmory import InMemorySplitStorage, InMemoryRuleBasedSegmentStorage, InMemoryRuleBasedSegmentStorageAsync, \ + InMemorySplitStorageAsync +from splitio.models import splits, rule_based_segments +from splitio.storage import FlagSetsFilter +from tests.sync.test_splits_synchronizer import splits_raw as split_sample + +class StorageHelperTests(object): + + rbs = rule_based_segments.from_raw({ + "changeNumber": 123, + "name": "sample_rule_based_segment", + "status": "ACTIVE", + "trafficTypeName": "user", + "excluded":{ + "keys":["mauro@split.io","gaston@split.io"], + "segments":[{"name":"excluded_segment", "type": "standard"}] + }, + "conditions": [ + {"matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": None + } + ] + }, + } + ] + }) + + def test_update_feature_flag_storage(self, mocker): + storage = mocker.Mock(spec=InMemorySplitStorage) + split = splits.from_raw(split_sample[0]) + + self.added = [] + self.deleted = [] + self.change_number = 0 + def update(to_add, to_delete, change_number): + self.added = to_add + self.deleted = to_delete + self.change_number = change_number + storage.update = update + + def is_flag_set_exist(flag_set): + return False + storage.is_flag_set_exist = is_flag_set_exist + + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + + self.clear = 0 + def clear(): + self.clear += 1 + storage.clear = clear + + update_feature_flag_storage(storage, [split], 123, True) + assert self.added[0] == split + assert self.deleted == [] + assert self.change_number == 123 + assert self.clear == 1 + + class flag_set_filter2(): + def should_filter(): + return True + def intersect(sets): + return False + storage.flag_set_filter = flag_set_filter2 + storage.flag_set_filter.flag_sets = set({'set1', 'set2'}) + + self.clear = 0 + update_feature_flag_storage(storage, [split], 123) + assert self.added == [] + assert self.deleted[0] == split.name + assert self.clear == 0 + + class flag_set_filter3(): + def should_filter(): + return True + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter3 + storage.flag_set_filter.flag_sets = set({'set1', 'set2'}) + + def is_flag_set_exist2(flag_set): + return True + storage.is_flag_set_exist = is_flag_set_exist2 + update_feature_flag_storage(storage, [split], 123) + assert self.added[0] == split + assert self.deleted == [] + + split_json = split_sample[0] + split_json['conditions'].append({ + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "segment1" + }, + "whitelistMatcherData": None + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 30 + }, + { + "treatment": "off", + "size": 70 + } + ] + } + ) + + split = splits.from_raw(split_json) + storage.config_flag_sets_used = 0 + assert update_feature_flag_storage(storage, [split], 123) == {'segment1'} + + def test_get_valid_flag_sets(self): + flag_sets = ['set1', 'set2'] + config_flag_sets = FlagSetsFilter([]) + assert get_valid_flag_sets(flag_sets, config_flag_sets) == ['set1', 'set2'] + + config_flag_sets = FlagSetsFilter(['set1']) + assert get_valid_flag_sets(flag_sets, config_flag_sets) == ['set1'] + + flag_sets = ['set2', 'set3'] + config_flag_sets = FlagSetsFilter(['set1', 'set2']) + assert get_valid_flag_sets(flag_sets, config_flag_sets) == ['set2'] + + flag_sets = ['set3', 'set4'] + config_flag_sets = FlagSetsFilter(['set1', 'set2']) + assert get_valid_flag_sets(flag_sets, config_flag_sets) == [] + + flag_sets = [] + config_flag_sets = FlagSetsFilter(['set1', 'set2']) + assert get_valid_flag_sets(flag_sets, config_flag_sets) == [] + + def test_combine_valid_flag_sets(self): + results_set = [{'set1', 'set2'}, {'set2', 'set3'}] + assert combine_valid_flag_sets(results_set) == {'set1', 'set2', 'set3'} + + results_set = [{}, {'set2', 'set3'}] + assert combine_valid_flag_sets(results_set) == {'set2', 'set3'} + + results_set = ['set1', {'set2', 'set3'}] + assert combine_valid_flag_sets(results_set) == {'set2', 'set3'} + + def test_update_rule_base_segment_storage(self, mocker): + storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorage) + self.added = [] + self.deleted = [] + self.change_number = 0 + def update(to_add, to_delete, change_number): + self.added = to_add + self.deleted = to_delete + self.change_number = change_number + storage.update = update + + self.clear = 0 + def clear(): + self.clear += 1 + storage.clear = clear + + segments = update_rule_based_segment_storage(storage, [self.rbs], 123) + assert self.added[0] == self.rbs + assert self.deleted == [] + assert self.change_number == 123 + assert segments == {'excluded_segment', 'employees'} + assert self.clear == 0 + + segments = update_rule_based_segment_storage(storage, [self.rbs], 123, True) + assert self.clear == 1 + + def test_get_standard_segment_in_rbs_storage(self, mocker): + events_queue = queue.Queue() + storage = InMemoryRuleBasedSegmentStorage(events_queue) + segments = update_rule_based_segment_storage(storage, [self.rbs], 123) + assert get_standard_segment_names_in_rbs_storage(storage) == {'excluded_segment', 'employees'} + + @pytest.mark.asyncio + async def test_get_standard_segment_in_rbs_storage(self, mocker): + storage = InMemoryRuleBasedSegmentStorageAsync(asyncio.Queue()) + segments = await update_rule_based_segment_storage_async(storage, [self.rbs], 123) + assert await get_standard_segment_names_in_rbs_storage_async(storage) == {'excluded_segment', 'employees'} + + @pytest.mark.asyncio + async def test_update_rule_base_segment_storage_async(self, mocker): + storage = mocker.Mock(spec=InMemoryRuleBasedSegmentStorageAsync) + self.added = [] + self.deleted = [] + self.change_number = 0 + async def update(to_add, to_delete, change_number): + self.added = to_add + self.deleted = to_delete + self.change_number = change_number + storage.update = update + + self.clear = 0 + async def clear(): + self.clear += 1 + storage.clear = clear + + segments = await update_rule_based_segment_storage_async(storage, [self.rbs], 123) + assert self.added[0] == self.rbs + assert self.deleted == [] + assert self.change_number == 123 + assert segments == {'excluded_segment', 'employees'} + + segments = await update_rule_based_segment_storage_async(storage, [self.rbs], 123, True) + assert self.clear == 1 + + @pytest.mark.asyncio + async def test_update_feature_flag_storage_async(self, mocker): + storage = mocker.Mock(spec=InMemorySplitStorageAsync) + split = splits.from_raw(split_sample[0]) + + self.added = [] + self.deleted = [] + self.change_number = 0 + async def get(flag_name): + return None + storage.get = get + + async def update(to_add, to_delete, change_number): + self.added = to_add + self.deleted = to_delete + self.change_number = change_number + storage.update = update + + async def is_flag_set_exist(flag_set): + return False + storage.is_flag_set_exist = is_flag_set_exist + + class flag_set_filter(): + def should_filter(): + return False + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter + storage.flag_set_filter.flag_sets = {} + + self.clear = 0 + async def clear(): + self.clear += 1 + storage.clear = clear + + await update_feature_flag_storage_async(storage, [split], 123, True) + assert self.added[0] == split + assert self.deleted == [] + assert self.change_number == 123 + assert self.clear == 1 + + class flag_set_filter2(): + def should_filter(): + return True + def intersect(sets): + return False + storage.flag_set_filter = flag_set_filter2 + storage.flag_set_filter.flag_sets = set({'set1', 'set2'}) + + async def get(flag_name): + return split + storage.get = get + + self.clear = 0 + await update_feature_flag_storage_async(storage, [split], 123) + assert self.added == [] + assert self.deleted[0] == split.name + assert self.clear == 0 + + class flag_set_filter3(): + def should_filter(): + return True + def intersect(sets): + return True + storage.flag_set_filter = flag_set_filter3 + storage.flag_set_filter.flag_sets = set({'set1', 'set2'}) + + async def is_flag_set_exist2(flag_set): + return True + storage.is_flag_set_exist = is_flag_set_exist2 + await update_feature_flag_storage_async(storage, [split], 123) + assert self.added[0] == split + assert self.deleted == [] + + split_json = split_sample[0] + split_json['conditions'].append({ + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": False, + "userDefinedSegmentMatcherData": { + "segmentName": "segment1" + }, + "whitelistMatcherData": None + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 30 + }, + { + "treatment": "off", + "size": 70 + } + ] + } + ) + + split = splits.from_raw(split_json) + storage.config_flag_sets_used = 0 + assert await update_feature_flag_storage_async(storage, [split], 123) == {'segment1'} \ No newline at end of file